diff --git a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt index c6fe29112..b1c53a30c 100644 --- a/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt +++ b/onnx/src/androidMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/Fan2D106FaceAlignmentModel.kt @@ -6,10 +6,14 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment import android.graphics.Bitmap +import androidx.camera.core.ImageProxy import org.jetbrains.kotlinx.dl.api.inference.InferenceModel +import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.doWithRotation import org.jetbrains.kotlinx.dl.dataset.preprocessing.* +import org.jetbrains.kotlinx.dl.dataset.preprocessing.camerax.toBitmap import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape /** @@ -17,21 +21,34 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape * * @param [internalModel] model used to make predictions */ -public class Fan2D106FaceAlignmentModel(override val internalModel: OnnxInferenceModel) : FaceAlignmentModelBase(), +public class Fan2D106FaceAlignmentModel(override val internalModel: OnnxInferenceModel) : + FaceAlignmentModelBase(), CameraXCompatibleModel, InferenceModel by internalModel { override val outputName: String = "fc1" - override var targetRotation: Float = 0f + override var targetRotation: Int = 0 override val preprocessing: Operation> = pipeline() .resize { outputWidth = internalModel.inputDimensions[2].toInt() outputHeight = internalModel.inputDimensions[1].toInt() } - .rotate { degrees = targetRotation } + .rotate { degrees = targetRotation.toFloat() } .toFloatArray { layout = TensorLayout.NCHW } override fun copy(copiedModelName: String?, saveOptimizerState: Boolean, copyWeights: Boolean): InferenceModel { return Fan2D106FaceAlignmentModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights)) } +} + +/** + * Detects [Landmark] objects on the given [imageProxy]. + */ +public fun FaceAlignmentModelBase.detectLandmarks(imageProxy: ImageProxy): List { + if (this is CameraXCompatibleModel) { + return doWithRotation(imageProxy.imageInfo.rotationDegrees) { + detectLandmarks(imageProxy.toBitmap()) + } + } + return detectLandmarks(imageProxy.toBitmap(applyRotation = true)) } \ No newline at end of file