Skip to content

Commit

Permalink
[Other] Add Function For Aligning Face With Five Points (#1124)
Browse files Browse the repository at this point in the history
* 更新5点人脸对齐的代码

* 更新代码格式

* 解决comment

* update example

* 更新注释

Co-authored-by: DefTruth <[email protected]>
  • Loading branch information
Zheng-Bicheng and DefTruth authored Jan 14, 2023
1 parent 1dabfdf commit c797d31
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 13 deletions.
11 changes: 6 additions & 5 deletions examples/vision/facedet/scrfd/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)

# 指定下载解压后的fastdeploy库路径
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")

include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)

# 添加FastDeploy依赖头文件

include_directories(${FASTDEPLOY_INCS})

add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# 添加FastDeploy库依赖
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
add_executable(infer_with_face_align_demo ${PROJECT_SOURCE_DIR}/infer_with_face_align.cc)
target_link_libraries(infer_with_face_align_demo ${FASTDEPLOY_LIBS})

add_executable(infer_without_face_align_demo ${PROJECT_SOURCE_DIR}/infer_without_face_align.cc)
target_link_libraries(infer_without_face_align_demo ${FASTDEPLOY_LIBS})
14 changes: 11 additions & 3 deletions examples/vision/facedet/scrfd/cpp/README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ make -j
wget https://bj.bcebos.com/paddlehub/fastdeploy/scrfd_500m_bnkps_shape640x640.onnx
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg

# SCRFD
# CPU推理
./infer_without_face_align_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 0
# GPU推理
./infer_without_face_align_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 1
# GPU上TensorRT推理
./infer_without_face_align_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 2

# SCRFD + FaceAlign
# CPU推理
./infer_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 0
./infer_with_face_align_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 0
# GPU推理
./infer_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 1
./infer_with_face_align_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 1
# GPU上TensorRT推理
./infer_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 2
./infer_with_face_align_demo scrfd_500m_bnkps_shape640x640.onnx test_lite_face_detector_3.jpg 2
```

运行完成可视化结果如下图所示
Expand Down
115 changes: 115 additions & 0 deletions examples/vision/facedet/scrfd/cpp/infer_with_face_align.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision.h"

void CpuInfer(const std::string& model_file, const std::string& image_file) {
auto model = fastdeploy::vision::facedet::SCRFD(model_file);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}

auto im = cv::imread(image_file);

fastdeploy::vision::FaceDetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;

auto vis_im_list =
fastdeploy::vision::utils::AlignFaceWithFivePoints(im, res);
if (!vis_im_list.empty()) {
cv::imwrite("vis_result.jpg", vis_im_list[0]);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
}

void GpuInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model = fastdeploy::vision::facedet::SCRFD(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}

auto im = cv::imread(image_file);

fastdeploy::vision::FaceDetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;

auto vis_im_list =
fastdeploy::vision::utils::AlignFaceWithFivePoints(im, res);
if (!vis_im_list.empty()) {
cv::imwrite("vis_result.jpg", vis_im_list[0]);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
}

void TrtInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
option.UseTrtBackend();
option.SetTrtInputShape("images", {1, 3, 640, 640});
auto model = fastdeploy::vision::facedet::SCRFD(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}

auto im = cv::imread(image_file);

fastdeploy::vision::FaceDetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;

auto vis_im_list =
fastdeploy::vision::utils::AlignFaceWithFivePoints(im, res);
if (!vis_im_list.empty()) {
cv::imwrite("vis_result.jpg", vis_im_list[0]);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
}

int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout
<< "Usage: infer_demo path/to/model path/to/image run_option, "
"e.g ./infer_model scrfd_500m_bnkps_shape640x640.onnx ./test.jpeg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}

if (std::atoi(argv[3]) == 0) {
CpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 1) {
GpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 2) {
TrtInfer(argv[1], argv[2]);
}
return 0;
}
File renamed without changes.
151 changes: 151 additions & 0 deletions fastdeploy/vision/utils/face_align.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// reference:
// https://github.com/deepinsight/insightface/blob/master/recognition/_tools_/cpp_align/face_align.h
#include "fastdeploy/vision/utils/utils.h"

namespace fastdeploy {
namespace vision {
namespace utils {

cv::Mat MeanAxis0(const cv::Mat& src) {
int num = src.rows;
int dim = src.cols;
cv::Mat output(1, dim, CV_32F);
for (int i = 0; i < dim; i++) {
float sum = 0;
for (int j = 0; j < num; j++) {
sum += src.at<float>(j, i);
}
output.at<float>(0, i) = sum / num;
}
return output;
}

cv::Mat ElementwiseMinus(const cv::Mat& A, const cv::Mat& B) {
cv::Mat output(A.rows, A.cols, A.type());
assert(B.cols == A.cols);
if (B.cols == A.cols) {
for (int i = 0; i < A.rows; i++) {
for (int j = 0; j < B.cols; j++) {
output.at<float>(i, j) = A.at<float>(i, j) - B.at<float>(0, j);
}
}
}
return output;
}

cv::Mat VarAxis0(const cv::Mat& src) {
cv::Mat temp_ = ElementwiseMinus(src, MeanAxis0(src));
cv::multiply(temp_, temp_, temp_);
return MeanAxis0(temp_);
}

int MatrixRank(cv::Mat M) {
cv::Mat w, u, vt;
cv::SVD::compute(M, w, u, vt);
cv::Mat1b non_zero_singular_values = w > 0.0001;
int rank = countNonZero(non_zero_singular_values);
return rank;
}

cv::Mat SimilarTransform(cv::Mat& dst, cv::Mat& src) {
int num = dst.rows;
int dim = dst.cols;
cv::Mat src_mean = MeanAxis0(dst);
cv::Mat dst_mean = MeanAxis0(src);
cv::Mat src_demean = ElementwiseMinus(dst, src_mean);
cv::Mat dst_demean = ElementwiseMinus(src, dst_mean);
cv::Mat A = (dst_demean.t() * src_demean) / static_cast<float>(num);
cv::Mat d(dim, 1, CV_32F);
d.setTo(1.0f);
if (cv::determinant(A) < 0) {
d.at<float>(dim - 1, 0) = -1;
}
cv::Mat T = cv::Mat::eye(dim + 1, dim + 1, CV_32F);
cv::Mat U, S, V;
cv::SVD::compute(A, S, U, V);
int rank = MatrixRank(A);
if (rank == 0) {
assert(rank == 0);
} else if (rank == dim - 1) {
if (cv::determinant(U) * cv::determinant(V) > 0) {
T.rowRange(0, dim).colRange(0, dim) = U * V;
} else {
int s = d.at<float>(dim - 1, 0) = -1;
d.at<float>(dim - 1, 0) = -1;

T.rowRange(0, dim).colRange(0, dim) = U * V;
cv::Mat diag_ = cv::Mat::diag(d);
cv::Mat twp = diag_ * V; // np.dot(np.diag(d), V.T)
cv::Mat B = cv::Mat::zeros(3, 3, CV_8UC1);
cv::Mat C = B.diag(0);
T.rowRange(0, dim).colRange(0, dim) = U * twp;
d.at<float>(dim - 1, 0) = s;
}
} else {
cv::Mat diag_ = cv::Mat::diag(d);
cv::Mat twp = diag_ * V.t(); // np.dot(np.diag(d), V.T)
cv::Mat res = U * twp; // U
T.rowRange(0, dim).colRange(0, dim) = -U.t() * twp;
}
cv::Mat var_ = VarAxis0(src_demean);
float val = cv::sum(var_).val[0];
cv::Mat res;
cv::multiply(d, S, res);
float scale = 1.0 / val * cv::sum(res).val[0];
T.rowRange(0, dim).colRange(0, dim) =
-T.rowRange(0, dim).colRange(0, dim).t();
cv::Mat temp1 = T.rowRange(0, dim).colRange(0, dim); // T[:dim, :dim]
cv::Mat temp2 = src_mean.t(); // src_mean.T
cv::Mat temp3 = temp1 * temp2; // np.dot(T[:dim, :dim], src_mean.T)
cv::Mat temp4 = scale * temp3;
T.rowRange(0, dim).colRange(dim, dim + 1) = -(temp4 - dst_mean.t());
T.rowRange(0, dim).colRange(0, dim) *= scale;
return T;
}

std::vector<cv::Mat> AlignFaceWithFivePoints(
cv::Mat& image, FaceDetectionResult& result,
std::vector<std::array<float, 2>> std_landmarks,
std::array<int, 2> output_size) {
FDASSERT(std_landmarks.size() == 5, "The landmarks.size() must be 5.")
FDASSERT(!image.empty(), "The input_image can't be empty.")
std::vector<cv::Mat> output_images(result.boxes.size());
if (result.boxes.empty()) {
FDWARNING << "The result is empty." << std::endl;
return output_images;
}

cv::Mat src(5, 2, CV_32FC1, std_landmarks.data());
for (int i = 0; i < result.landmarks.size(); i += 5) {
cv::Mat dst(5, 2, CV_32FC1, result.landmarks.data() + i);
cv::Mat m = SimilarTransform(dst, src);
cv::Mat map_matrix;
cv::Rect map_matrix_r = cv::Rect(0, 0, 3, 2);
cv::Mat(m, map_matrix_r).copyTo(map_matrix);
cv::Mat cropped_image_aligned;
cv::warpAffine(image, cropped_image_aligned, map_matrix,
{output_size[0], output_size[1]});
if (cropped_image_aligned.empty()) {
FDWARNING << "croppedImageAligned is empty." << std::endl;
}
output_images.push_back(cropped_image_aligned);
}
return output_images;
}
} // namespace utils
} // namespace vision
} // namespace fastdeploy
26 changes: 21 additions & 5 deletions fastdeploy/vision/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,32 @@ void SortDetectionResult(DetectionResult* output);
void SortDetectionResult(FaceDetectionResult* result);

// L2 Norm / cosine similarity (for face recognition, ...)
FASTDEPLOY_DECL std::vector<float> L2Normalize(
const std::vector<float>& values);
FASTDEPLOY_DECL std::vector<float>
L2Normalize(const std::vector<float>& values);

FASTDEPLOY_DECL float CosineSimilarity(const std::vector<float>& a,
const std::vector<float>& b,
bool normalized = true);

bool CropImageByBox(Mat& src_im, Mat* dst_im,
const std::vector<float>& box, std::vector<float>* center,
std::vector<float>* scale, const float expandratio = 0.3);
/** \brief Do face align for model with five points.
*
* \param[in] image The original image
* \param[in] result FaceDetectionResult
* \param[in] std_landmarks Standard face template
* \param[in] output_size The size of output mat
*/
FASTDEPLOY_DECL std::vector<cv::Mat> AlignFaceWithFivePoints(
cv::Mat& image, FaceDetectionResult& result,
std::vector<std::array<float, 2>> std_landmarks = {{38.2946f, 51.6963f},
{73.5318f, 51.5014f},
{56.0252f, 71.7366f},
{41.5493f, 92.3655f},
{70.7299f, 92.2041f}},
std::array<int, 2> output_size = {112, 112});

bool CropImageByBox(Mat& src_im, Mat* dst_im, const std::vector<float>& box,
std::vector<float>* center, std::vector<float>* scale,
const float expandratio = 0.3);

/**
* Function: for keypoint detection model, fine positioning of keypoints in
Expand Down

0 comments on commit c797d31

Please sign in to comment.