Skip to content

Commit

Permalink
fixup! Add face alignment model for android
Browse files Browse the repository at this point in the history
  • Loading branch information
juliabeliaeva committed Sep 30, 2022
1 parent 5dbecf1 commit d5383a4
Showing 1 changed file with 20 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,49 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment

import android.graphics.Bitmap
import androidx.camera.core.ImageProxy
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

/**
* The light-weight API for solving Face Alignment task.
*
* @param [internalModel] model used to make predictions
*/
public class Fan2D106FaceAlignmentModel(override val internalModel: OnnxInferenceModel) : FaceAlignmentModelBase<Bitmap>(),
public class Fan2D106FaceAlignmentModel(override val internalModel: OnnxInferenceModel) :
FaceAlignmentModelBase<Bitmap>(),
CameraXCompatibleModel, InferenceModel by internalModel {

override val outputName: String = "fc1"
override var targetRotation: Float = 0f
override var targetRotation: Int = 0

override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>> = pipeline<Bitmap>()
.resize {
outputWidth = internalModel.inputDimensions[2].toInt()
outputHeight = internalModel.inputDimensions[1].toInt()
}
.rotate { degrees = targetRotation }
.rotate { degrees = targetRotation.toFloat() }
.toFloatArray { layout = TensorLayout.NCHW }

override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel {
return Fan2D106FaceAlignmentModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights))
}
}

/**
* Detects [Landmark] objects on the given [imageProxy].
*/
public fun FaceAlignmentModelBase<Bitmap>.detectLandmarks(imageProxy: ImageProxy): List<Landmark> {
if (this is CameraXCompatibleModel) {
return doWithRotation(imageProxy.imageInfo.rotationDegrees) {
detectLandmarks(imageProxy.toBitmap())
}
}
return detectLandmarks(imageProxy.toBitmap(applyRotation = true))
}

0 comments on commit d5383a4

Please sign in to comment.