Add face detection models and face alignment model for android (#461)
* Unify detected objects property names and reorder constructor parameters

* Allow unlabeled DetectedObject instances

* Add face detection models

* Add face alignment model for android

* fixup! Add face detection models

Add a method working on ImageProxy, make NMS public.

* fixup! Add face alignment model for android
juliabeliaeva authored Oct 3, 2022
1 parent 9393809 commit 484b0e8
Showing 32 changed files with 483 additions and 99 deletions.
@@ -5,5 +5,8 @@

package org.jetbrains.kotlinx.dl.api.inference.facealignment

- /** Face landmark located on (xRate/xMaxSize, yRate/yMaxSize) point. */
- public data class Landmark(public val xRate: Float, public val yRate: Float)
+ /**
+  * Represents a face landmark as a point on the image with two coordinates relative to the top-left corner.
+  * Both coordinates have values between 0 and 1.
+  */
+ public data class Landmark(public val x: Float, public val y: Float)
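
Since both coordinates are now relative values in [0, 1], mapping a Landmark back to pixel positions is a single multiplication per axis. A minimal sketch, not part of this commit (the toPixels helper is hypothetical):

// Hypothetical helper: convert a relative Landmark to absolute pixel
// coordinates for an image of the given size.
fun Landmark.toPixels(imageWidth: Int, imageHeight: Int): Pair<Int, Int> =
    (x * imageWidth).toInt() to (y * imageHeight).toInt()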
@@ -8,18 +8,18 @@ package org.jetbrains.kotlinx.dl.api.inference.objectdetection
/**
* This data class represents the detected object on the given image.
*
- * @property [classLabel] The predicted class's name
- * @property [probability] The probability of the predicted class.
- * @property [xMax] The maximum X coordinate for the bounding box containing the predicted object.
* @property [xMin] The minimum X coordinate for the bounding box containing the predicted object.
- * @property [yMax] The maximum Y coordinate for the bounding box containing the predicted object.
+ * @property [xMax] The maximum X coordinate for the bounding box containing the predicted object.
* @property [yMin] The minimum Y coordinate for the bounding box containing the predicted object.
+ * @property [yMax] The maximum Y coordinate for the bounding box containing the predicted object.
+ * @property [probability] The probability of the predicted class.
+ * @property [label] The predicted class's name
*/
public data class DetectedObject(
- val classLabel: String,
- val probability: Float,
- val xMax: Float,
val xMin: Float,
+ val xMax: Float,
+ val yMin: Float,
val yMax: Float,
- val yMin: Float
+ val probability: Float,
+ val label: String? = null
)
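
With the reordered constructor, the bounding-box coordinates come first and the label is now an optional last parameter. A short usage sketch, illustrative only (the values are made up):

// Illustrative only: construct a detection with the new parameter order;
// label defaults to null, so fall back to a generic name when printing.
val detection = DetectedObject(xMin = 0.1f, xMax = 0.4f, yMin = 0.2f, yMax = 0.6f, probability = 0.93f)
println("Found ${detection.label ?: "object"} with probability ${detection.probability}")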
@@ -8,10 +8,10 @@ package org.jetbrains.kotlinx.dl.api.inference.posedetection
/**
* This data class represents the human's pose detected on the given image.
*
- * @property [poseLandmarks] The list of detected [PoseLandmark]s for the given image.
+ * @property [landmarks] The list of detected [PoseLandmark]s for the given image.
* @property [edges] The list of edges connecting the detected [PoseLandmark]s.
*/
public data class DetectedPose(
- val poseLandmarks: List<PoseLandmark>,
+ val landmarks: List<PoseLandmark>,
val edges: List<PoseEdge>
)
@@ -10,5 +10,5 @@ import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
/** This data class represents a few detected poses on the given image. */
public data class MultiPoseDetectionResult(
/** The list of pairs DetectedObject - DetectedPose. */
- val multiplePoses: List<Pair<DetectedObject, DetectedPose>>
+ val poses: List<Pair<DetectedObject, DetectedPose>>
)
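
A one-line illustration of the renamed property, not from this commit (multiPoseResult stands for any MultiPoseDetectionResult instance):

// Illustrative only: count detected people above a confidence threshold.
val confidentPeople = multiPoseResult.poses.count { (person, _) -> person.probability > 0.5f }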
@@ -3,14 +3,14 @@ package org.jetbrains.kotlinx.dl.api.inference.posedetection
/**
* This data class represents the line connecting two points [PoseLandmark] of human's pose.
*
- * @property [poseEdgeLabel] The predicted pose edge label.
+ * @property [label] The predicted pose edge label.
* @property [probability] The probability of the predicted class.
* @property [start] The start point of the edge.
* @property [end] The end point of the edge.
*/
public data class PoseEdge(
- val poseEdgeLabel: String,
- val probability: Float,
val start: PoseLandmark,
val end: PoseLandmark,
+ val probability: Float,
+ val label: String,
)
@@ -8,14 +8,14 @@ package org.jetbrains.kotlinx.dl.api.inference.posedetection
/**
* This data class represents one point of the detected human's pose.
*
- * @property [poseLandmarkLabel] The predicted pose landmark label.
- * @property [probability] The probability of the predicted class.
* @property [x] The value of `x` coordinate.
* @property [y] The value of `y` coordinate.
+ * @property [probability] The probability of the predicted class.
+ * @property [label] The predicted pose landmark label.
*/
public data class PoseLandmark(
- val poseLandmarkLabel: String,
- val probability: Float,
val x: Float,
val y: Float,
+ val probability: Float,
+ val label: String,
)
@@ -0,0 +1,37 @@
/*
* Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package examples.onnx.faces

import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.api.inference.onnx.inferAndCloseUsing
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.resize
import org.jetbrains.kotlinx.dl.visualization.swing.createDetectedObjectsPanel
import org.jetbrains.kotlinx.dl.visualization.swing.showFrame
import java.awt.image.BufferedImage
import java.io.File

fun main() {
    val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))
    val model = ONNXModels.FaceDetection.UltraFace320.pretrainedModel(modelHub)

    model.inferAndCloseUsing(ExecutionProvider.CPU()) {
        val file = getFileFromResource("datasets/poses/multi/1.jpg")
        val image = ImageConverter.toBufferedImage(file)
        val faces = it.detectFaces(image)

        val width = 600
        val resize = pipeline<BufferedImage>().resize {
            outputWidth = width
            outputHeight = width * image.height / image.width
        }
        showFrame("Detected Faces", createDetectedObjectsPanel(resize.apply(image), faces))
    }
}
@@ -23,7 +23,7 @@ fun main() {
detectionModel.detectObjects(imageFile = imageFile)

detectedObjects.forEach {
println("Found ${it.classLabel} with probability ${it.probability}")
println("Found ${it.label} with probability ${it.probability}")
}
}
}
@@ -25,7 +25,7 @@ fun main() {
val detectedObjects = detectionModel.detectObjects(image)

detectedObjects.forEach {
println("Found ${it.classLabel} with probability ${it.probability}")
println("Found ${it.label} with probability ${it.probability}")
}

showFrame("Detection result for ${file.name}", createDetectedObjectsPanel(image, detectedObjects))
@@ -26,7 +26,7 @@ fun main() {
val detectedObjects = detectionModel.detectObjects(image)

detectedObjects.forEach {
println("Found ${it.classLabel} with probability ${it.probability}")
println("Found ${it.label} with probability ${it.probability}")
}

showFrame("Detection result for ${file.name}", createDetectedObjectsPanel(image, detectedObjects))
@@ -29,7 +29,7 @@ fun ssdLightAPI() {
detectionModel.detectObjects(imageFile = imageFile, topK = 50)

detectedObjects.forEach {
println("Found ${it.classLabel} with probability ${it.probability}")
println("Found ${it.label} with probability ${it.probability}")
}
}
}
@@ -36,7 +36,7 @@ fun main() {
val detectedObjects = detectionModel.detectObjects(image, topK = 20)

detectedObjects.forEach {
println("Found ${it.classLabel} with probability ${it.probability}")
println("Found ${it.label} with probability ${it.probability}")
}

val displayedImage = pipeline<BufferedImage>()
@@ -30,7 +30,7 @@ fun ssdMobileLightAPI() {
detectionModel.detectObjects(imageFile = imageFile, topK = 50)

detectedObjects.forEach {
println("Found ${it.classLabel} with probability ${it.probability}")
println("Found ${it.label} with probability ${it.probability}")
}
}
}
@@ -9,15 +9,9 @@ import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDObjectDetectionModel
- import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
- import org.jetbrains.kotlinx.dl.dataset.preprocessing.rescale
- import org.jetbrains.kotlinx.dl.dataset.preprocessor.fileLoader
- import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.convert
import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.resize
- import org.jetbrains.kotlinx.dl.dataset.preprocessor.image.toFloatArray
- import org.jetbrains.kotlinx.dl.dataset.preprocessor.toImageShape
import org.jetbrains.kotlinx.dl.visualization.swing.createDetectedObjectsPanel
import org.jetbrains.kotlinx.dl.visualization.swing.showFrame
import java.awt.image.BufferedImage
@@ -41,7 +35,7 @@ fun main() {
val detectedObjects = detectionModel.detectObjects(image, topK = 50)

detectedObjects.forEach {
println("Found ${it.classLabel} with probability ${it.probability}")
println("Found ${it.label} with probability ${it.probability}")
}

val displayedImage = pipeline<BufferedImage>()
@@ -63,22 +63,21 @@ fun multiPoseDetectionMoveNet()

for (keyPointIdx in 0..16) {
val poseLandmark = PoseLandmark(
- poseLandmarkLabel = "",
x = floats[3 * keyPointIdx + 1],
y = floats[3 * keyPointIdx],
- probability = floats[3 * keyPointIdx + 2]
+ probability = floats[3 * keyPointIdx + 2],
+ label = ""
)
foundPoseLandmarks.add(poseLandmark)
}

// [ymin, xmin, ymax, xmax, score]
val detectedObject = DetectedObject(
- classLabel = "person",
- probability = probability,
- yMin = floats[51],
xMin = floats[52],
+ xMax = floats[54],
+ yMin = floats[51],
yMax = floats[53],
- xMax = floats[54]
+ probability = probability
)
val detectedPose = DetectedPose(foundPoseLandmarks, emptyList())

@@ -35,14 +35,14 @@ fun multiPoseDetectionMoveNetLightAPI() {
val image = ImageConverter.toBufferedImage(getFileFromResource("datasets/poses/multi/$i.jpg"))
val detectedPoses = poseDetectionModel.detectPoses(image = image, confidence = 0.05f)

- detectedPoses.multiplePoses.forEach { detectedPose ->
- println("Found ${detectedPose.first.classLabel} with probability ${detectedPose.first.probability}")
- detectedPose.second.poseLandmarks.forEach {
- println(" Found ${it.poseLandmarkLabel} with probability ${it.probability}")
+ detectedPoses.poses.forEach { detectedPose ->
+ println("Found ${detectedPose.first.label} with probability ${detectedPose.first.probability}")
+ detectedPose.second.landmarks.forEach {
+ println(" Found ${it.label} with probability ${it.probability}")
}

detectedPose.second.edges.forEach {
println(" The ${it.poseEdgeLabel} starts at ${it.start.poseLandmarkLabel} and ends with ${it.end.poseLandmarkLabel}")
println(" The ${it.label} starts at ${it.start.label} and ends with ${it.end.label}")
}
}
result[image] = detectedPoses
@@ -36,12 +36,12 @@ fun poseDetectionMoveNetLightAPI() {
val image = ImageConverter.toBufferedImage(file)
val detectedPose = poseDetectionModel.detectPose(image)

- detectedPose.poseLandmarks.forEach {
- println("Found ${it.poseLandmarkLabel} with probability ${it.probability}")
+ detectedPose.landmarks.forEach {
+ println("Found ${it.label} with probability ${it.probability}")
}

detectedPose.edges.forEach {
println("The ${it.poseEdgeLabel} starts at ${it.start.poseLandmarkLabel} and ends with ${it.end.poseLandmarkLabel}")
println("The ${it.label} starts at ${it.start.label} and ends with ${it.end.label}")
}

result[image] = detectedPose
@@ -82,10 +82,10 @@ fun poseDetectionMoveNet() {
val foundPoseLandmarks = mutableListOf<PoseLandmark>()
for (i in rawPoseLandMarks.indices) {
val poseLandmark = PoseLandmark(
- poseLandmarkLabel = keypoints[i]!!,
x = rawPoseLandMarks[i][1],
y = rawPoseLandMarks[i][0],
- probability = rawPoseLandMarks[i][2]
+ probability = rawPoseLandMarks[i][2],
+ label = keypoints[i]!!
)
foundPoseLandmarks.add(i, poseLandmark)
}
@@ -29,7 +29,7 @@ class PoseDetectionTestSuite {
model.use { poseDetectionModel ->
val imageFile = getFileFromResource("datasets/poses/single/1.jpg")
val detectedPose = poseDetectionModel.detectPose(imageFile = imageFile)
- assertEquals(17, detectedPose.poseLandmarks.size)
+ assertEquals(17, detectedPose.landmarks.size)
assertEquals(18, detectedPose.edges.size)
}
}
@@ -42,7 +42,7 @@ class PoseDetectionTestSuite {
model.use { poseDetectionModel ->
val imageFile = getFileFromResource("datasets/poses/single/1.jpg")
val detectedPose = poseDetectionModel.detectPose(imageFile = imageFile)
- assertEquals(17, detectedPose.poseLandmarks.size)
+ assertEquals(17, detectedPose.landmarks.size)
assertEquals(18, detectedPose.edges.size)
}
}
@@ -55,9 +55,9 @@ class PoseDetectionTestSuite {
model.use { poseDetectionModel ->
val imageFile = getFileFromResource("datasets/poses/multi/1.jpg")
val detectedPoses = poseDetectionModel.detectPoses(imageFile = imageFile)
- assertEquals(3, detectedPoses.multiplePoses.size)
- detectedPoses.multiplePoses.forEach {
- assertEquals(17, it.second.poseLandmarks.size)
+ assertEquals(3, detectedPoses.poses.size)
+ detectedPoses.poses.forEach {
+ assertEquals(17, it.second.landmarks.size)
assertEquals(18, it.second.edges.size)
}
}
@@ -4,6 +4,8 @@ import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.InputType
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.classification.ImageRecognitionModel
+ import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.Fan2D106FaceAlignmentModel
+ import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.FaceDetectionModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModelMetadata
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.SinglePoseDetectionModel
@@ -235,4 +237,67 @@ public object ONNXModels {
}
}
}

    /** Face detection models */
    public sealed class FaceDetection(override val inputShape: LongArray, override val modelRelativePath: String) :
        OnnxModelType<OnnxInferenceModel, FaceDetectionModel> {
        override val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>>
            get() = defaultPreprocessor

        override fun pretrainedModel(modelHub: ModelHub): FaceDetectionModel {
            return FaceDetectionModel(modelHub.loadModel(this))
        }

        /**
         * Ultra-lightweight face detection model.
         *
         * Model accepts input of the shape (1 x 3 x 240 x 320)
         * Model outputs two arrays (1 x 4420 x 2) and (1 x 4420 x 4) of scores and boxes.
         *
         * Threshold filtration and non-max suppression are applied during postprocessing.
         *
         * @see <a href="https://github.com/onnx/models/tree/main/vision/body_analysis/ultraface">Ultra-lightweight face detection model</a>
         */
        public object UltraFace320 : FaceDetection(longArrayOf(3L, 240, 320), "ultraface_320")

        /**
         * Ultra-lightweight face detection model.
         *
         * Model accepts input of the shape (1 x 3 x 480 x 640)
         * Model outputs two arrays (1 x 4420 x 2) and (1 x 4420 x 4) of scores and boxes.
         *
         * Threshold filtration and non-max suppression are applied during postprocessing.
         *
         * @see <a href="https://github.com/onnx/models/tree/main/vision/body_analysis/ultraface">Ultra-lightweight face detection model</a>
         */
        public object UltraFace640 : FaceDetection(longArrayOf(3L, 480, 640), "ultraface_640")

        public companion object {
            public val defaultPreprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>> =
                pipeline<Pair<FloatArray, TensorShape>>()
                    .normalize {
                        mean = floatArrayOf(127f, 127f, 127f)
                        std = floatArrayOf(128f, 128f, 128f)
                        channelsLast = false
                    }
        }
    }

    /** Face alignment models */
    public sealed class FaceAlignment<T : OnnxInferenceModel, U : InferenceModel> : OnnxModelType<T, U> {
        /**
         * This model is a neural network for face alignment that takes RGB images of faces as input
         * and produces the coordinates of 106 face landmarks.
         *
         * The model has
         * - an input with the shape (1x3x192x192)
         * - an output with the shape (1x212)
         */
        public object Fan2d106 : FaceAlignment<OnnxInferenceModel, Fan2D106FaceAlignmentModel>() {
            override val inputShape: LongArray = longArrayOf(3L, 192, 192)
            override val modelRelativePath: String = "fan_2d_106"
            override fun pretrainedModel(modelHub: ModelHub): Fan2D106FaceAlignmentModel {
                return Fan2D106FaceAlignmentModel(modelHub.loadModel(this))
            }
        }
    }
}
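
Two illustrative sketches follow; neither is part of the commit. First, the arithmetic performed by FaceDetection.defaultPreprocessor: with mean 127 and std 128, every channel value v in [0, 255] is mapped to (v - 127) / 128, roughly the [-1, 1] range the UltraFace models expect.

// Minimal sketch of the per-channel normalization configured above (assumed semantics of normalize).
fun normalizeChannelValue(v: Float): Float = (v - 127f) / 128f
// normalizeChannelValue(0f) == -0.9921875, normalizeChannelValue(255f) == 1.0

Second, a minimal usage sketch for the new face alignment type; the detectLandmarks method name is an assumption and may differ from the actual Fan2D106FaceAlignmentModel API.

import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import java.io.File

fun main() {
    val modelHub = ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))
    val alignmentModel = ONNXModels.FaceAlignment.Fan2d106.pretrainedModel(modelHub)
    alignmentModel.use {
        // Assumed method: returns 106 Landmark points with coordinates relative
        // to the top-left corner of the input image (see Landmark above).
        val landmarks = it.detectLandmarks(imageFile = File("path/to/face.jpg"))
        landmarks.forEach { landmark -> println("(${landmark.x}, ${landmark.y})") }
    }
}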