diff --git a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/facealignment/Landmark.kt b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/facealignment/Landmark.kt index 1de3a5c4a..ec762b9da 100644 --- a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/facealignment/Landmark.kt +++ b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/facealignment/Landmark.kt @@ -5,5 +5,8 @@ package org.jetbrains.kotlinx.dl.api.inference.facealignment -/** Face landmark located on (xRate/xMaxSize, yRate/yMaxSize) point. */ -public data class Landmark(public val xRate: Float, public val yRate: Float) +/** + * Represents a face landmark as a point on the image with two coordinates relative to the top-left corner. + * Both coordinates have values between 0 and 1. + * */ +public data class Landmark(public val x: Float, public val y: Float) diff --git a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/objectdetection/DetectedObject.kt b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/objectdetection/DetectedObject.kt index 75fd96af1..f86495c19 100644 --- a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/objectdetection/DetectedObject.kt +++ b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/objectdetection/DetectedObject.kt @@ -8,18 +8,18 @@ package org.jetbrains.kotlinx.dl.api.inference.objectdetection /** * This data class represents the detected object on the given image. * - * @property [classLabel] The predicted class's name - * @property [probability] The probability of the predicted class. - * @property [xMax] The maximum X coordinate for the bounding box containing the predicted object. * @property [xMin] The minimum X coordinate for the bounding box containing the predicted object. - * @property [yMax] The maximum Y coordinate for the bounding box containing the predicted object. 
+ * @property [xMax] The maximum X coordinate for the bounding box containing the predicted object. * @property [yMin] The minimum Y coordinate for the bounding box containing the predicted object. + * @property [yMax] The maximum Y coordinate for the bounding box containing the predicted object. + * @property [probability] The probability of the predicted class. + * @property [label] The predicted class's name */ public data class DetectedObject( - val classLabel: String, - val probability: Float, - val xMax: Float, val xMin: Float, + val xMax: Float, + val yMin: Float, val yMax: Float, - val yMin: Float + val probability: Float, + val label: String? = null ) diff --git a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/DetectedPose.kt b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/DetectedPose.kt index 77bb0dec6..329ca1286 100644 --- a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/DetectedPose.kt +++ b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/DetectedPose.kt @@ -8,10 +8,10 @@ package org.jetbrains.kotlinx.dl.api.inference.posedetection /** * This data class represents the human's pose detected on the given image. * - * @property [poseLandmarks] The list of detected [PoseLandmark]s for the given image. + * @property [landmarks] The list of detected [PoseLandmark]s for the given image. * @property [edges] The list of edges connecting the detected [PoseLandmark]s. 
*/ public data class DetectedPose( - val poseLandmarks: List, + val landmarks: List, val edges: List ) diff --git a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/MultiPoseDetectionResult.kt b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/MultiPoseDetectionResult.kt index df3c170f9..4ee96416f 100644 --- a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/MultiPoseDetectionResult.kt +++ b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/MultiPoseDetectionResult.kt @@ -10,5 +10,5 @@ import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject /** This data class represents a few detected poses on the given image. */ public data class MultiPoseDetectionResult( /** The list of pairs DetectedObject - DetectedPose. */ - val multiplePoses: List> + val poses: List> ) diff --git a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseEdge.kt b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseEdge.kt index 323bba3e1..cfeeb894d 100644 --- a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseEdge.kt +++ b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseEdge.kt @@ -3,14 +3,14 @@ package org.jetbrains.kotlinx.dl.api.inference.posedetection /** * This data class represents the line connecting two points [PoseLandmark] of human's pose. * - * @property [poseEdgeLabel] The predicted pose edge label. + * @property [label] The predicted pose edge label. * @property [probability] The probability of the predicted class. * @property [start] The [PoseLandmark] at which the edge starts. * @property [end] The [PoseLandmark] at which the edge ends.
*/ public data class PoseEdge( - val poseEdgeLabel: String, - val probability: Float, val start: PoseLandmark, val end: PoseLandmark, + val probability: Float, + val label: String, ) diff --git a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseLandmark.kt b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseLandmark.kt index af62afa21..c78ff6a07 100644 --- a/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseLandmark.kt +++ b/api/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/posedetection/PoseLandmark.kt @@ -8,14 +8,14 @@ package org.jetbrains.kotlinx.dl.api.inference.posedetection /** * This data class represents one point of the detected human's pose. * - * @property [poseLandmarkLabel] The predicted pose landmark label. - * @property [probability] The probability of the predicted class. * @property [x] The value of `x` coordinate. * @property [y] The value of `y` coordinate. + * @property [probability] The probability of the predicted class. + * @property [label] The predicted pose landmark label. */ public data class PoseLandmark( - val poseLandmarkLabel: String, - val probability: Float, val x: Float, val y: Float, + val probability: Float, + val label: String, ) diff --git a/examples/src/main/kotlin/examples/onnx/faces/FaceDetectionWithVisualization.kt b/examples/src/main/kotlin/examples/onnx/faces/FaceDetectionWithVisualization.kt new file mode 100644 index 000000000..f57eac150 --- /dev/null +++ b/examples/src/main/kotlin/examples/onnx/faces/FaceDetectionWithVisualization.kt @@ -0,0 +1,37 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ + +package examples.onnx.faces + +import examples.transferlearning.getFileFromResource +import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub +import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider +import org.jetbrains.kotlinx.dl.api.inference.onnx.inferAndCloseUsing +import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter +import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline +import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.resize +import org.jetbrains.kotlinx.dl.visualization.swing.createDetectedObjectsPanel +import org.jetbrains.kotlinx.dl.visualization.swing.showFrame +import java.awt.image.BufferedImage +import java.io.File + +fun main() { + val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels")) + val model = ONNXModels.FaceDetection.UltraFace320.pretrainedModel(modelHub) + + model.inferAndCloseUsing(ExecutionProvider.CPU()) { + val file = getFileFromResource("datasets/poses/multi/1.jpg") + val image = ImageConverter.toBufferedImage(file) + val faces = it.detectFaces(image) + + val width = 600 + val resize = pipeline().resize { + outputWidth = width + outputHeight = width * image.height / image.width + } + showFrame("Detected Faces", createDetectedObjectsPanel(resize.apply(image), faces)) + } +} \ No newline at end of file diff --git a/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPI.kt b/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPI.kt index 66ace6e25..60ed8f5b9 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPI.kt @@ -23,7 +23,7 @@ fun main() { detectionModel.detectObjects(imageFile = imageFile) detectedObjects.forEach { - println("Found ${it.classLabel} with 
probability ${it.probability}") + println("Found ${it.label} with probability ${it.probability}") } } } diff --git a/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualization.kt b/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualization.kt index e18468d4f..a442bac56 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualization.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualization.kt @@ -25,7 +25,7 @@ fun main() { val detectedObjects = detectionModel.detectObjects(image) detectedObjects.forEach { - println("Found ${it.classLabel} with probability ${it.probability}") + println("Found ${it.label} with probability ${it.probability}") } showFrame("Detection result for ${file.name}", createDetectedObjectsPanel(image, detectedObjects)) diff --git a/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualizationAndInputShape.kt b/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualizationAndInputShape.kt index 9d0804425..f98980716 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualizationAndInputShape.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/efficientdet/EfficientDetD2LightAPIwithVisualizationAndInputShape.kt @@ -26,7 +26,7 @@ fun main() { val detectedObjects = detectionModel.detectObjects(image) detectedObjects.forEach { - println("Found ${it.classLabel} with probability ${it.probability}") + println("Found ${it.label} with probability ${it.probability}") } showFrame("Detection result for ${file.name}", createDetectedObjectsPanel(image, detectedObjects)) diff --git a/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPI.kt 
b/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPI.kt index c92092281..41d619d4e 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPI.kt @@ -29,7 +29,7 @@ fun ssdLightAPI() { detectionModel.detectObjects(imageFile = imageFile, topK = 50) detectedObjects.forEach { - println("Found ${it.classLabel} with probability ${it.probability}") + println("Found ${it.label} with probability ${it.probability}") } } } diff --git a/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPIWithVisualization.kt b/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPIWithVisualization.kt index d9cf2bbb4..8a7e83e64 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPIWithVisualization.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSDLightAPIWithVisualization.kt @@ -36,7 +36,7 @@ fun main() { val detectedObjects = detectionModel.detectObjects(image, topK = 20) detectedObjects.forEach { - println("Found ${it.classLabel} with probability ${it.probability}") + println("Found ${it.label} with probability ${it.probability}") } val displayedImage = pipeline() diff --git a/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPI.kt b/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPI.kt index 663ba8124..ae233a0f6 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPI.kt @@ -30,7 +30,7 @@ fun ssdMobileLightAPI() { detectionModel.detectObjects(imageFile = imageFile, topK = 50) detectedObjects.forEach { - println("Found ${it.classLabel} with probability ${it.probability}") + println("Found ${it.label} with probability ${it.probability}") } } } diff --git 
a/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPIwithVisualization.kt b/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPIwithVisualization.kt index ea26c9dbb..00c4df8c8 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPIwithVisualization.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/ssdmobile/SSDMobileLightAPIwithVisualization.kt @@ -9,15 +9,9 @@ import examples.transferlearning.getFileFromResource import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDObjectDetectionModel -import org.jetbrains.kotlinx.dl.dataset.image.ColorMode import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline -import org.jetbrains.kotlinx.dl.dataset.preprocessing.rescale -import org.jetbrains.kotlinx.dl.dataset.preprocessor.fileLoader -import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.convert import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.resize -import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.toFloatArray -import org.jetbrains.kotlinx.dl.dataset.preprocessor.toImageShape import org.jetbrains.kotlinx.dl.visualization.swing.createDetectedObjectsPanel import org.jetbrains.kotlinx.dl.visualization.swing.showFrame import java.awt.image.BufferedImage @@ -41,7 +35,7 @@ fun main() { val detectedObjects = detectionModel.detectObjects(image, topK = 50) detectedObjects.forEach { - println("Found ${it.classLabel} with probability ${it.probability}") + println("Found ${it.label} with probability ${it.probability}") } val displayedImage = pipeline() diff --git a/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt 
b/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt index 2de6006ce..a04ebf918 100644 --- a/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt +++ b/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt @@ -63,22 +63,21 @@ fun multiPoseDetectionMoveNet() { for (keyPointIdx in 0..16) { val poseLandmark = PoseLandmark( - poseLandmarkLabel = "", x = floats[3 * keyPointIdx + 1], y = floats[3 * keyPointIdx], - probability = floats[3 * keyPointIdx + 2] + probability = floats[3 * keyPointIdx + 2], + label = "" ) foundPoseLandmarks.add(poseLandmark) } // [ymin, xmin, ymax, xmax, score] val detectedObject = DetectedObject( - classLabel = "person", - probability = probability, - yMin = floats[51], xMin = floats[52], + xMax = floats[54], + yMin = floats[51], yMax = floats[53], - xMax = floats[54] + probability = probability ) val detectedPose = DetectedPose(foundPoseLandmarks, emptyList()) diff --git a/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNetLightAPI.kt b/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNetLightAPI.kt index 78c5c825d..83e81acb4 100644 --- a/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNetLightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNetLightAPI.kt @@ -35,14 +35,14 @@ fun multiPoseDetectionMoveNetLightAPI() { val image = ImageConverter.toBufferedImage(getFileFromResource("datasets/poses/multi/$i.jpg")) val detectedPoses = poseDetectionModel.detectPoses(image = image, confidence = 0.05f) - detectedPoses.multiplePoses.forEach { detectedPose -> - println("Found ${detectedPose.first.classLabel} with probability ${detectedPose.first.probability}") - detectedPose.second.poseLandmarks.forEach { - println(" Found ${it.poseLandmarkLabel} with probability 
${it.probability}") + detectedPoses.poses.forEach { detectedPose -> + println("Found ${detectedPose.first.label} with probability ${detectedPose.first.probability}") + detectedPose.second.landmarks.forEach { + println(" Found ${it.label} with probability ${it.probability}") } detectedPose.second.edges.forEach { - println(" The ${it.poseEdgeLabel} starts at ${it.start.poseLandmarkLabel} and ends with ${it.end.poseLandmarkLabel}") + println(" The ${it.label} starts at ${it.start.label} and ends with ${it.end.label}") } } result[image] = detectedPoses diff --git a/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/PoseDetectionMoveNetLightAPI.kt b/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/PoseDetectionMoveNetLightAPI.kt index 51a258992..ffa0e2187 100644 --- a/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/PoseDetectionMoveNetLightAPI.kt +++ b/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/PoseDetectionMoveNetLightAPI.kt @@ -36,12 +36,12 @@ fun poseDetectionMoveNetLightAPI() { val image = ImageConverter.toBufferedImage(file) val detectedPose = poseDetectionModel.detectPose(image) - detectedPose.poseLandmarks.forEach { - println("Found ${it.poseLandmarkLabel} with probability ${it.probability}") + detectedPose.landmarks.forEach { + println("Found ${it.label} with probability ${it.probability}") } detectedPose.edges.forEach { - println("The ${it.poseEdgeLabel} starts at ${it.start.poseLandmarkLabel} and ends with ${it.end.poseLandmarkLabel}") + println("The ${it.label} starts at ${it.start.label} and ends with ${it.end.label}") } result[image] = detectedPose diff --git a/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt b/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt index 2db63d178..b0cc89cb3 100644 --- a/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt +++ 
b/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt @@ -82,10 +82,10 @@ fun poseDetectionMoveNet() { val foundPoseLandmarks = mutableListOf() for (i in rawPoseLandMarks.indices) { val poseLandmark = PoseLandmark( - poseLandmarkLabel = keypoints[i]!!, x = rawPoseLandMarks[i][1], y = rawPoseLandMarks[i][0], - probability = rawPoseLandMarks[i][2] + probability = rawPoseLandMarks[i][2], + label = keypoints[i]!! ) foundPoseLandmarks.add(i, poseLandmark) } diff --git a/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt b/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt index e6c3783de..154cf02ba 100644 --- a/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt +++ b/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt @@ -29,7 +29,7 @@ class PoseDetectionTestSuite { model.use { poseDetectionModel -> val imageFile = getFileFromResource("datasets/poses/single/1.jpg") val detectedPose = poseDetectionModel.detectPose(imageFile = imageFile) - assertEquals(17, detectedPose.poseLandmarks.size) + assertEquals(17, detectedPose.landmarks.size) assertEquals(18, detectedPose.edges.size) } } @@ -42,7 +42,7 @@ class PoseDetectionTestSuite { model.use { poseDetectionModel -> val imageFile = getFileFromResource("datasets/poses/single/1.jpg") val detectedPose = poseDetectionModel.detectPose(imageFile = imageFile) - assertEquals(17, detectedPose.poseLandmarks.size) + assertEquals(17, detectedPose.landmarks.size) assertEquals(18, detectedPose.edges.size) } } @@ -55,9 +55,9 @@ class PoseDetectionTestSuite { model.use { poseDetectionModel -> val imageFile = getFileFromResource("datasets/poses/multi/1.jpg") val detectedPoses = poseDetectionModel.detectPoses(imageFile = imageFile) - assertEquals(3, detectedPoses.multiplePoses.size) - detectedPoses.multiplePoses.forEach { - assertEquals(17, it.second.poseLandmarks.size) + assertEquals(3, 
detectedPoses.poses.size) + detectedPoses.poses.forEach { + assertEquals(17, it.second.landmarks.size) assertEquals(18, it.second.edges.size) } } diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt index 8bdbd212b..0fca69637 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt @@ -4,6 +4,8 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.InputType import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelHub import org.jetbrains.kotlinx.dl.api.inference.onnx.classification.ImageRecognitionModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.Fan2D106FaceAlignmentModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.FaceDetectionModel import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModel import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModelMetadata import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.SinglePoseDetectionModel @@ -235,4 +237,67 @@ public object ONNXModels { } } } + + /** Face detection models */ + public sealed class FaceDetection(override val inputShape: LongArray, override val modelRelativePath: String) : + OnnxModelType { + override val preprocessor: Operation, Pair> + get() = defaultPreprocessor + + override fun pretrainedModel(modelHub: ModelHub): FaceDetectionModel { + return FaceDetectionModel(modelHub.loadModel(this)) + } + + /** + * Ultra-lightweight face detection model. + * + * Model accepts input of the shape (1 x 3 x 240 x 320) + * Model outputs two arrays (1 x 4420 x 2) and (1 x 4420 x 4) of scores and boxes. 
+ * + * Threshold filtration and non-max suppression are applied during postprocessing. + * + * @see Ultra-lightweight face detection model + */ + public object UltraFace320 : FaceDetection(longArrayOf(3L, 240, 320), "ultraface_320") + + /** + * Ultra-lightweight face detection model. + * + * Model accepts input of the shape (1 x 3 x 480 x 640) + * Model outputs two arrays (1 x 17640 x 2) and (1 x 17640 x 4) of scores and boxes. + * + * Threshold filtration and non-max suppression are applied during postprocessing. + * + * @see Ultra-lightweight face detection model + */ + public object UltraFace640 : FaceDetection(longArrayOf(3L, 480, 640), "ultraface_640") + + public companion object { + public val defaultPreprocessor: Operation, Pair> = + pipeline>() + .normalize { + mean = floatArrayOf(127f, 127f, 127f) + std = floatArrayOf(128f, 128f, 128f) + channelsLast = false + } + } + } + + /** Face alignment models */ + public sealed class FaceAlignment : OnnxModelType { + /** + * This model is a neural network for face alignment that takes RGB images of faces as input and produces coordinates of 106 face landmarks.
+ * + * The model has + * - an input with the shape (1x3x192x192) + * - an output with the shape (1x212) + */ + public object Fan2d106 : FaceAlignment() { + override val inputShape: LongArray = longArrayOf(3L, 192, 192) + override val modelRelativePath: String = "fan_2d_106" + override fun pretrainedModel(modelHub: ModelHub): Fan2D106FaceAlignmentModel { + return Fan2D106FaceAlignmentModel(modelHub.loadModel(this)) + } + } + } } diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModel.kt new file mode 100644 index 000000000..b751df93c --- /dev/null +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModel.kt @@ -0,0 +1,58 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment + +import android.graphics.Bitmap +import androidx.camera.core.ImageProxy +import org.jetbrains.kotlinx.dl.api.inference.InferenceModel +import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject +import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation +import org.jetbrains.kotlinx.dl.dataset.preprocessing.* +import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap +import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape + +/** + * Face detection model implementation.
+ * + * @see ONNXModels.FaceDetection.UltraFace320 + * @see ONNXModels.FaceDetection.UltraFace640 + */ +public class FaceDetectionModel(override val internalModel: OnnxInferenceModel) : FaceDetectionModelBase(), + CameraXCompatibleModel, InferenceModel by internalModel { + override var targetRotation: Int = 0 + override val preprocessing: Operation> + get() = pipeline() + .rotate { degrees = targetRotation.toFloat() } + .resize { + outputWidth = internalModel.inputDimensions[2].toInt() + outputHeight = internalModel.inputDimensions[1].toInt() + } + .toFloatArray { layout = TensorLayout.NCHW } + .call(ONNXModels.FaceDetection.defaultPreprocessor) + + override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel { + return FaceDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights)) + } +} + +/** + * Detects [topK] faces on the given [imageProxy]. If [topK] is negative all detected faces are returned. + * @param [iouThreshold] threshold IoU value for the non-maximum suppression applied during postprocessing + */ +public fun FaceDetectionModelBase.detectFaces(imageProxy: ImageProxy, + topK: Int = 5, + iouThreshold: Float = 0.5f +): List { + if (this is CameraXCompatibleModel) { + return doWithRotation(imageProxy.imageInfo.rotationDegrees) { + detectFaces(imageProxy.toBitmap(), topK, iouThreshold) + } + } + return detectFaces(imageProxy.toBitmap(applyRotation = true), topK, iouThreshold) +} \ No newline at end of file diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt new file mode 100644 index 000000000..b1c53a30c --- /dev/null +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt @@ -0,0 +1,54 @@ +/* + * Copyright 2022 JetBrains 
s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. + */ + +package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment + +import android.graphics.Bitmap +import androidx.camera.core.ImageProxy +import org.jetbrains.kotlinx.dl.api.inference.InferenceModel +import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark +import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation +import org.jetbrains.kotlinx.dl.dataset.preprocessing.* +import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap +import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape + +/** + * The light-weight API for solving Face Alignment task. + * + * @param [internalModel] model used to make predictions + */ +public class Fan2D106FaceAlignmentModel(override val internalModel: OnnxInferenceModel) : + FaceAlignmentModelBase(), + CameraXCompatibleModel, InferenceModel by internalModel { + + override val outputName: String = "fc1" + override var targetRotation: Int = 0 + + override val preprocessing: Operation> get() = pipeline() // getter, as in FaceDetectionModel: re-reads targetRotation on every access + .resize { + outputWidth = internalModel.inputDimensions[2].toInt() + outputHeight = internalModel.inputDimensions[1].toInt() + } + .rotate { degrees = targetRotation.toFloat() } + .toFloatArray { layout = TensorLayout.NCHW } + + override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel { + return Fan2D106FaceAlignmentModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights)) + } +} + +/** + * Detects [Landmark] objects on the given [imageProxy].
+ */
+public fun FaceAlignmentModelBase.detectLandmarks(imageProxy: ImageProxy): List {
+    if (this is CameraXCompatibleModel) {
+        return doWithRotation(imageProxy.imageInfo.rotationDegrees) {
+            detectLandmarks(imageProxy.toBitmap())
+        }
+    }
+    return detectLandmarks(imageProxy.toBitmap(applyRotation = true))
+}
\ No newline at end of file
diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModelBase.kt
new file mode 100644
index 000000000..400716be2
--- /dev/null
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModelBase.kt
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
+ */
+
+package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment
+
+import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
+import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxHighLevelModel
+import java.lang.Float.min
+import kotlin.math.max
+
+/**
+ * Base class for face detection models.
+ *
+ * Converts the "scores" and "boxes" model outputs to [DetectedObject]s and offers
+ * non-maximum suppression helpers to filter overlapping detections.
+ */
+public abstract class FaceDetectionModelBase<I> : OnnxHighLevelModel<I, List<DetectedObject>> {
+
+    /**
+     * Collects every box whose score for a class index starting from 1 is greater than [THRESHOLD].
+     * Class index 0 is skipped (presumably the background class).
+     */
+    override fun convert(output: Map<String, Any>): List<DetectedObject> {
+        val scores = (output["scores"] as Array<*>)[0] as Array<FloatArray>
+        val boxes = (output["boxes"] as Array<*>)[0] as Array<FloatArray>
+
+        if (scores.isEmpty()) return emptyList()
+
+        val result = mutableListOf<DetectedObject>()
+        for (classIndex in 1 until scores[0].size) {
+            for ((box, classScores) in boxes.zip(scores)) {
+                val score = classScores[classIndex]
+                if (score > THRESHOLD) {
+                    result.add(DetectedObject(box[0], box[2], box[1], box[3], score))
+                }
+            }
+        }
+        return result
+    }
+
+    /**
+     * Detects [topK] faces on the given [image]. If [topK] is negative or zero all detected faces are returned.
+     * @param [iouThreshold] threshold IoU value for the non-maximum suppression applied during postprocessing
+     */
+    public fun detectFaces(image: I, topK: Int = 5, iouThreshold: Float = 0.5f): List<DetectedObject> {
+        val detectedObjects = predict(image)
+        return suppressNonMaxBoxes(detectedObjects, topK, iouThreshold)
+    }
+
+    public companion object {
+        private const val THRESHOLD = 0.7
+        private const val EPS = Float.MIN_VALUE
+
+        /**
+         * Performs non-maximum suppression to filter out boxes with the IoU greater than or equal to the threshold.
+         * @param [boxes] boxes to filter
+         * @param [topK] how many boxes to include in the result. Negative or zero means to include everything.
+         * @param [threshold] threshold IoU value
+         */
+        public fun suppressNonMaxBoxes(boxes: List<DetectedObject>,
+                                       topK: Int = -1,
+                                       threshold: Float = 0.5f
+        ): List<DetectedObject> {
+            val sortedBoxes = boxes.toMutableList().apply { sortByDescending { it.probability } }
+            val result = mutableListOf<DetectedObject>()
+            while (sortedBoxes.isNotEmpty()) {
+                val box = sortedBoxes.removeFirst()
+                result.add(box)
+                if (topK > 0 && result.size >= topK) break
+
+                sortedBoxes.removeIf { iou(box, it) >= threshold }
+            }
+            return result
+        }
+
+        /**
+         * Computes the intersection over union value for the [box1] and [box2].
+         * Returns zero for boxes which do not intersect.
+         */
+        public fun iou(box1: DetectedObject, box2: DetectedObject): Float {
+            val xMin = max(box1.xMin, box2.xMin)
+            val yMin = max(box1.yMin, box2.yMin)
+            val xMax = min(box1.xMax, box2.xMax)
+            val yMax = min(box1.yMax, box2.yMax)
+            // Clamp both sides at zero: for disjoint boxes the differences are negative,
+            // and the product of two negative sides would be reported as a positive overlap.
+            val overlap = max(0f, xMax - xMin) * max(0f, yMax - yMin)
+            return overlap / (box1.area() + box2.area() - overlap + EPS)
+        }
+
+        private fun DetectedObject.area() = (xMax - xMin) * (yMax - yMin)
+    }
+}
\ No newline at end of file
diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt
index b4547a42f..bf7cf8522 100644
--- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt
+++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt
@@ -46,13 +46,13 @@ public abstract class EfficientDetObjectDetectionModelBase : ObjectDetectionM
             val probability = items[i][5]
             if (probability != 0.0f) {
                 val detectedObject = DetectedObject(
-                    classLabel = classLabels[items[i][6].toInt()]!!,
-                    probability = probability,
-                    // left, bot, right, top
                     xMin = minOf(items[i][2] / internalModel.inputDimensions[1], 1.0f),
-                    yMax = minOf(items[i][3] / internalModel.inputDimensions[0], 1.0f),
                     xMax = minOf(items[i][4] / internalModel.inputDimensions[1], 1.0f),
-                    yMin = minOf(items[i][1] / internalModel.inputDimensions[0], 1.0f)
+                    // left, bot, right, top
+                    yMin = minOf(items[i][1] / internalModel.inputDimensions[0], 1.0f),
+                    yMax = minOf(items[i][3] / internalModel.inputDimensions[0], 1.0f),
+                    probability = probability,
+                    label = classLabels[items[i][6].toInt()]
                 )
                 foundObjects.add(detectedObject)
             }
@@ -78,13 +78,13 @@ public abstract class SSDLikeModelBase(protected val metadata: SSDLikeModelMe
         val foundObjects = mutableListOf()
         for (i in 0 until numberOfFoundObjects) {
             val
detectedObject = DetectedObject( - classLabel = classLabels[classIndices[i].toInt()] ?: "Unknown", - probability = probabilities[i], - // left, bot, right, top xMin = boxes[i][metadata.xMinIdx], - yMin = boxes[i][metadata.yMinIdx], xMax = boxes[i][metadata.xMinIdx + 2], - yMax = boxes[i][metadata.yMinIdx + 2] + // left, bot, right, top + yMin = boxes[i][metadata.yMinIdx], + yMax = boxes[i][metadata.yMinIdx + 2], + probability = probabilities[i], + label = classLabels[classIndices[i].toInt()] ) foundObjects.add(detectedObject) } diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt index 4488adc4f..be4cccd7c 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt @@ -38,22 +38,21 @@ public abstract class MultiPoseDetectionModelBase : OnnxHighLevelModel : OnnxHighLevelModel + val filteredPoses = result.poses.filter { (detectedObject, _) -> detectedObject.probability > confidence } return MultiPoseDetectionResult(filteredPoses) } - - private companion object { - private const val CLASS_LABEL = "person" - } } \ No newline at end of file diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt index ccc8a8825..57c54fc2a 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt @@ -39,7 +39,7 @@ public abstract 
class SinglePoseDetectionModelBase : OnnxHighLevelModel() for (i in rawPoseLandMarks.indices) { val poseLandmark = PoseLandmark( - poseLandmarkLabel = keyPointsLabels[i]!!, + label = keyPointsLabels[i]!!, x = rawPoseLandMarks[i][1], y = rawPoseLandMarks[i][0], probability = rawPoseLandMarks[i][2] @@ -66,10 +66,10 @@ internal fun buildPoseEdges(foundPoseLandmarks: List, edgeKeyPoint val endPoint = foundPoseLandmarks[it.second] foundPoseEdges.add( PoseEdge( - poseEdgeLabel = startPoint.poseLandmarkLabel + "_" + endPoint.poseLandmarkLabel, - probability = min(startPoint.probability, endPoint.probability), start = startPoint, - end = endPoint + end = endPoint, + probability = min(startPoint.probability, endPoint.probability), + label = startPoint.label + "_" + endPoint.label ) ) } diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt index 511424139..ee4bbf232 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/ONNXModels.kt @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModel import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.InputType import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelHub +import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.FaceDetectionModel import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.Fan2D106FaceAlignmentModel import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.EfficientDetObjectDetectionModel import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDMobileNetV1ObjectDetectionModel @@ -20,6 +21,7 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.SinglePoseDetec import 
org.jetbrains.kotlinx.dl.dataset.image.ColorMode
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.normalize
 import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
 import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
 
@@ -827,6 +829,52 @@ public object ONNXModels {
         }
     }
 
+    /** Face detection models and the default preprocessing shared by them. */
+    public sealed class FaceDetection(override val inputShape: LongArray, modelName: String) :
+        OnnxModelType {
+        override val modelRelativePath: String = "models/onnx/facealignment/$modelName"
+        override val preprocessor: Operation, Pair>
+            get() = defaultPreprocessor
+
+        override fun pretrainedModel(modelHub: ModelHub): FaceDetectionModel {
+            return FaceDetectionModel(modelHub.loadModel(this))
+        }
+
+        /**
+         * Ultra-lightweight face detection model.
+         *
+         * Model accepts input of the shape (1 x 3 x 240 x 320)
+         * Model outputs two arrays (1 x 4420 x 2) and (1 x 4420 x 4) of scores and boxes.
+         *
+         * Threshold filtration and non-max suppression are applied during postprocessing.
+         *
+         * @see Ultra-lightweight face detection model
+         */
+        public object UltraFace320 : FaceDetection(longArrayOf(3L, 240, 320), "ultraface_320")
+
+        /**
+         * Ultra-lightweight face detection model.
+         *
+         * Model accepts input of the shape (1 x 3 x 480 x 640)
+         * Model outputs two arrays (1 x N x 2) and (1 x N x 4) of scores and boxes, presumably with N = 17640 for this input size (TODO confirm).
+         *
+         * Threshold filtration and non-max suppression are applied during postprocessing.
+         *
+         * @see Ultra-lightweight face detection model
+         */
+        public object UltraFace640 : FaceDetection(longArrayOf(3L, 480, 640), "ultraface_640")
+
+        public companion object {
+            public val defaultPreprocessor: Operation, Pair> =
+                pipeline>()
+                    .normalize {
+                        mean = floatArrayOf(127f, 127f, 127f)
+                        std = floatArrayOf(128f, 128f, 128f)
+                    }
+                    .transpose { axes = intArrayOf(2, 0, 1) } // moves the last axis first, presumably HWC -> CHW (TODO confirm)
+        }
+    }
+
     /** Face alignment models and preprocessing.
*/
     public sealed class FaceAlignment(override val modelRelativePath: String) :
         OnnxModelType {
@@ -839,6 +887,7 @@ public object ONNXModels {
          */
         public object Fan2d106 :
             FaceAlignment("models/onnx/facealignment/fan_2d_106") {
+            override val inputShape: LongArray = longArrayOf(3L, 192L, 192L)
             override val preprocessor: Operation, Pair>
                 get() = Transpose(axes = intArrayOf(2, 0, 1))
 
diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModel.kt
new file mode 100644
index 000000000..821e1d553
--- /dev/null
+++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModel.kt
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
+ */
+
+package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment
+
+import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
+import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
+import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
+import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
+import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
+import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.convert
+import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.resize
+import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.toFloatArray
+import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
+import java.awt.image.BufferedImage
+
+/**
+ * Face detection model implementation for JVM. Input images are resized to the model input size,
+ * converted to RGB and to a float array, then passed to [ONNXModels.FaceDetection.defaultPreprocessor].
+ * @see ONNXModels.FaceDetection.UltraFace320
+ * @see ONNXModels.FaceDetection.UltraFace640
+ */
+public class FaceDetectionModel(override val internalModel: OnnxInferenceModel) :
+    FaceDetectionModelBase(), InferenceModel by internalModel {
+    override val preprocessing: Operation>
+        get() = pipeline()
+            .resize {
+                // assumes inputDimensions are ordered as (channels, height, width) - TODO confirm
+                outputWidth = internalModel.inputDimensions[2].toInt()
+                outputHeight = internalModel.inputDimensions[1].toInt()
+            }
+            .convert { colorMode = ColorMode.RGB }
+            .toFloatArray { }
+            .call(ONNXModels.FaceDetection.defaultPreprocessor)
+    // Creates a new wrapper around a copy of the internal model.
+    override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
+        return FaceDetectionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
+    }
+}
\ No newline at end of file
diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt
index d3cc8c81d..327d02128 100644
--- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt
+++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt
@@ -23,7 +23,6 @@ import java.io.File
 import java.io.IOException
 
 private const val OUTPUT_NAME = "fc1"
-private const val INPUT_SIZE = 192
 
 /**
  * The light-weight API for solving Face Alignment task via Fan2D106 model.
@@ -36,8 +35,8 @@ public class Fan2D106FaceAlignmentModel(override val internalModel: OnnxInferenc override val preprocessing: Operation> get() = pipeline() .resize { - outputHeight = INPUT_SIZE - outputWidth = INPUT_SIZE + outputWidth = internalModel.inputDimensions[2].toInt() + outputHeight = internalModel.inputDimensions[1].toInt() } .convert { colorMode = ColorMode.BGR } .toFloatArray {} diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt index b6b51ea96..c134ed89c 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt @@ -89,13 +89,13 @@ public class SSDObjectDetectionModel(override val internalModel: OnnxInferenceMo val foundObjects = mutableListOf() for (i in 0 until numberOfFoundObjects) { val detectedObject = DetectedObject( - classLabel = classLabels[classIndices[i].toInt()] ?: "Unknown", - probability = probabilities[i], - // left, bot, right, top xMin = boxes[i][metadata.xMinIdx], - yMin = boxes[i][metadata.yMinIdx], xMax = boxes[i][metadata.xMinIdx + 2], - yMax = boxes[i][metadata.yMinIdx + 2] + // left, bot, right, top + yMin = boxes[i][metadata.yMinIdx], + yMax = boxes[i][metadata.yMinIdx + 2], + probability = probabilities[i], + label = classLabels[classIndices[i].toInt()] ) foundObjects.add(detectedObject) } diff --git a/visualization/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/visualization/DrawDetectionResults.kt b/visualization/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/visualization/DrawDetectionResults.kt index 58b1df5f4..f75e66367 100644 --- a/visualization/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/visualization/DrawDetectionResults.kt +++ 
b/visualization/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/visualization/DrawDetectionResults.kt @@ -36,8 +36,10 @@ fun Canvas.drawObject( drawRect(rect, Paint(paint).apply { strokeWidth = frameWidth }) - val label = "${detectedObject.classLabel} : " + "%.2f".format(detectedObject.probability) - drawText(label, rect.left, rect.top - labelPaint.fontMetrics.descent - frameWidth / 2, labelPaint) + if (detectedObject.label != null) { + val label = "${detectedObject.label} : " + "%.2f".format(detectedObject.probability) + drawText(label, rect.left, rect.top - labelPaint.fontMetrics.descent - frameWidth / 2, labelPaint) + } } /** @@ -80,7 +82,7 @@ fun Canvas.drawPose( ) } - detectedPose.poseLandmarks.forEach { landmark -> + detectedPose.landmarks.forEach { landmark -> drawCircle(bounds.toViewX(landmark.x), bounds.toViewY(landmark.y), landmarkRadius, landmarkPaint) } } @@ -103,7 +105,7 @@ fun Canvas.drawMultiplePoses( landmarkRadius: Float, bounds: PreviewImageBounds = bounds() ) { - detectedPoses.multiplePoses.forEach { (detectedObject, detectedPose) -> + detectedPoses.poses.forEach { (detectedObject, detectedPose) -> drawPose(detectedPose, landmarkPaint, edgePaint, landmarkRadius, bounds) drawObject(detectedObject, objectPaint, labelPaint, bounds) } @@ -122,7 +124,7 @@ fun Canvas.drawLandmarks(landmarks: List, bounds: PreviewImageBounds = bounds() ) { landmarks.forEach { landmark -> - drawCircle(bounds.toViewX(landmark.xRate), bounds.toViewY(landmark.yRate), radius, paint) + drawCircle(bounds.toViewX(landmark.x), bounds.toViewY(landmark.y), radius, paint) } } diff --git a/visualization/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/visualization/swing/PlotDetectionResults.kt b/visualization/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/visualization/swing/PlotDetectionResults.kt index a30c6b8c0..09ff7f187 100644 --- a/visualization/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/visualization/swing/PlotDetectionResults.kt +++ 
b/visualization/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/visualization/swing/PlotDetectionResults.kt @@ -68,10 +68,12 @@ private fun Graphics2D.drawObject(detectedObject: DetectedObject, stroke = BasicStroke(frameWidth) draw(Rectangle2D.Float(x, y, detectedObject.xMax * width - x, detectedObject.yMax * height - y)) - val label = "${detectedObject.classLabel} : " + "%.2f".format(detectedObject.probability) - color = labelColor - font = font.deriveFont(Font.BOLD) - drawString(label, x, y - fontMetrics.maxDescent - frameWidth / 2) + if (detectedObject.label != null) { + val label = "${detectedObject.label} : " + "%.2f".format(detectedObject.probability) + color = labelColor + font = font.deriveFont(Font.BOLD) + drawString(label, x, y - fontMetrics.maxDescent - frameWidth / 2) + } } private fun Graphics2D.drawObjects(detectedObjects: List, width: Int, height: Int) { @@ -98,7 +100,7 @@ private fun Graphics2D.drawPose(detectedPose: DetectedPose, width: Int, height: val r = 3.0f color = landmarkColor - detectedPose.poseLandmarks.forEach { landmark -> + detectedPose.landmarks.forEach { landmark -> fill(Ellipse2D.Float(width * landmark.x - r, height * landmark.y - r, 2 * r, 2 * r)) } } @@ -108,7 +110,7 @@ private fun Graphics2D.drawMultiplePoses(multiPoseDetectionResult1: MultiPoseDet height: Int ) { setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON) - multiPoseDetectionResult1.multiplePoses.forEachIndexed { i, (detectedObject, detectedPose) -> + multiPoseDetectionResult1.poses.forEachIndexed { i, (detectedObject, detectedPose) -> drawPose( detectedPose, width, height, Color((6 - i) * 40, i * 20, i * 10), @@ -125,6 +127,6 @@ private fun Graphics2D.drawLandmarks(landmarks: List, width: Int, heig val r = 3.0f color = Color.RED landmarks.forEach { landmark -> - fill(Ellipse2D.Float(width * landmark.xRate - r, height * landmark.yRate - r, 2 * r, 2 * r)) + fill(Ellipse2D.Float(width * landmark.x - r, height * landmark.y - r, 2 * r, 2 * r)) } 
} \ No newline at end of file