Skip to content

Commit

Permalink
Develop (#19)
Browse files Browse the repository at this point in the history
* Add detection evaluation function (#37)

* Detection evaluation function

* Add license

Co-authored-by: Jason <[email protected]>

* Add UltraFace Model support  (PaddlePaddle#43)

* update .gitignore

* Added checking for cmake include dir

* fixed missing trt_backend option bug when init from trt

* remove un-need data layout and add pre-check for dtype

* changed RGB2BRG to BGR2RGB in ppcls model

* add model_zoo yolov6 c++/python demo

* fixed CMakeLists.txt typos

* update yolov6 cpp/README.md

* add yolox c++/pybind and model_zoo demo

* move some helpers to private

* fixed CMakeLists.txt typos

* add normalize with alpha and beta

* add version notes for yolov5/yolov6/yolox

* add copyright to yolov5.cc

* revert normalize

* fixed some bugs in yolox

* Add UltraFace Model support

Co-authored-by: huangjianhui <[email protected]>
Co-authored-by: Jason <[email protected]>
Co-authored-by: DefTruth <[email protected]>
  • Loading branch information
4 people authored Jul 27, 2022
1 parent 2330414 commit dca2a97
Show file tree
Hide file tree
Showing 26 changed files with 1,552 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ fastdeploy/LICENSE*
fastdeploy/ThirdPartyNotices*
*.so*
fastdeploy/libs/third_libs
csrcs/fastdeploy/core/config.h
1 change: 1 addition & 0 deletions csrcs/fastdeploy/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "fastdeploy/core/config.h"
#ifdef ENABLE_VISION
#include "fastdeploy/vision/deepcam/yolov5face.h"
#include "fastdeploy/vision/linzaer/ultraface.h"
#include "fastdeploy/vision/megvii/yolox.h"
#include "fastdeploy/vision/meituan/yolov6.h"
#include "fastdeploy/vision/ppcls/model.h"
Expand Down
35 changes: 35 additions & 0 deletions csrcs/fastdeploy/vision/linzaer/linzaer_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/pybind/main.h"

namespace fastdeploy {
void BindLinzaer(pybind11::module& m) {
auto linzaer_module = m.def_submodule(
"linzaer",
"https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB");
pybind11::class_<vision::linzaer::UltraFace, FastDeployModel>(linzaer_module,
"UltraFace")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::linzaer::UltraFace& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::FaceDetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::linzaer::UltraFace::size);
}
} // namespace fastdeploy
220 changes: 220 additions & 0 deletions csrcs/fastdeploy/vision/linzaer/ultraface.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision/linzaer/ultraface.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"

namespace fastdeploy {

namespace vision {

namespace linzaer {

UltraFace::UltraFace(const std::string& model_file,
const std::string& params_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
if (model_format == Frontend::ONNX) {
valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}

bool UltraFace::Initialize() {
// parameters for preprocess
size = {320, 240};

if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check if the input shape is dynamic after Runtime already initialized,
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
return true;
}

bool UltraFace::Preprocess(
Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
// ultraface's preprocess steps
// 1. resize
// 2. BGR->RGB
// 3. HWC->CHW
int resize_w = size[0];
int resize_h = size[1];
if (resize_h != mat->Height() || resize_w != mat->Width()) {
Resize::Run(mat, resize_w, resize_h);
}

BGR2RGB::Run(mat);
// Compute `result = mat * alpha + beta` directly by channel
// Reference: detect_imgs_onnx.py#L73
std::vector<float> alpha = {1.0f / 128.0f, 1.0f / 128.0f, 1.0f / 128.0f};
std::vector<float> beta = {-127.0f * (1.0f / 128.0f),
-127.0f * (1.0f / 128.0f),
-127.0f * (1.0f / 128.0f)}; // RGB;
Convert::Run(mat, alpha, beta);

// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};

HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c
return true;
}

bool UltraFace::Postprocess(
std::vector<FDTensor>& infer_result, FaceDetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold) {
// ultraface has 2 output tensors, scores & boxes
FDASSERT(
(infer_result.size() == 2),
"The default number of output tensor must be 2 according to ultraface.");
FDTensor& scores_tensor = infer_result.at(0); // (1,4420,2)
FDTensor& boxes_tensor = infer_result.at(1); // (1,4420,4)
FDASSERT((scores_tensor.shape[0] == 1), "Only support batch =1 now.");
FDASSERT((boxes_tensor.shape[0] == 1), "Only support batch =1 now.");

result->Clear();
// must be setup landmarks_per_face before reserve.
// ultraface detector does not detect landmarks by default.
result->landmarks_per_face = 0;
if (scores_tensor.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}
if (boxes_tensor.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}

float* scores_ptr = static_cast<float*>(scores_tensor.Data());
float* boxes_ptr = static_cast<float*>(boxes_tensor.Data());
const size_t num_bboxes = boxes_tensor.shape[1]; // e.g 4420
// fetch original image shape
auto iter_ipt = im_info.find("input_shape");
FDASSERT((iter_ipt != im_info.end()),
"Cannot find input_shape from im_info.");
float ipt_h = iter_ipt->second[0];
float ipt_w = iter_ipt->second[1];

// decode bounding boxes
for (size_t i = 0; i < num_bboxes; ++i) {
float confidence = scores_ptr[2 * i + 1];
// filter boxes by conf_threshold
if (confidence <= conf_threshold) {
continue;
}
float x1 = boxes_ptr[4 * i + 0] * ipt_w;
float y1 = boxes_ptr[4 * i + 1] * ipt_h;
float x2 = boxes_ptr[4 * i + 2] * ipt_w;
float y2 = boxes_ptr[4 * i + 3] * ipt_h;
result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
result->scores.push_back(confidence);
}

if (result->boxes.size() == 0) {
return true;
}

utils::NMS(result, nms_iou_threshold);

// scale and clip box
for (size_t i = 0; i < result->boxes.size(); ++i) {
result->boxes[i][0] = std::max(result->boxes[i][0], 0.0f);
result->boxes[i][1] = std::max(result->boxes[i][1], 0.0f);
result->boxes[i][2] = std::max(result->boxes[i][2], 0.0f);
result->boxes[i][3] = std::max(result->boxes[i][3], 0.0f);
result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
}
return true;
}

bool UltraFace::Predict(cv::Mat* im, FaceDetectionResult* result,
float conf_threshold, float nms_iou_threshold) {
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_START(0)
#endif

Mat mat(*im);
std::vector<FDTensor> input_tensors(1);

std::map<std::string, std::array<float, 2>> im_info;

// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};

if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}

#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(0, "Preprocess")
TIMERECORD_START(1)
#endif

input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(1, "Inference")
TIMERECORD_START(2)
#endif

if (!Postprocess(output_tensors, result, im_info, conf_threshold,
nms_iou_threshold)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}

#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(2, "Postprocess")
#endif
return true;
}

} // namespace linzaer
} // namespace vision
} // namespace fastdeploy
84 changes: 84 additions & 0 deletions csrcs/fastdeploy/vision/linzaer/ultraface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"

namespace fastdeploy {

namespace vision {

namespace linzaer {

class FASTDEPLOY_DECL UltraFace : public FastDeployModel {
public:
// 当model_format为ONNX时,无需指定params_file
// 当model_format为Paddle时,则需同时指定model_file & params_file
UltraFace(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);

// 定义模型的名称
std::string ModelName() const {
return "Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB";
}

// 模型预测接口,即用户调用的接口
// im 为用户的输入数据,目前对于CV均定义为cv::Mat
// result 为模型预测的输出结构体
// conf_threshold 为后处理的参数
// nms_iou_threshold 为后处理的参数
virtual bool Predict(cv::Mat* im, FaceDetectionResult* result,
float conf_threshold = 0.7f,
float nms_iou_threshold = 0.3f);

// 以下为模型在预测时的一些参数,基本是前后处理所需
// 用户在创建模型后,可根据模型的要求,以及自己的需求
// 对参数进行修改
// tuple of (width, height), default (320, 240)
std::vector<int> size;

private:
// 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作
bool Initialize();

// 输入图像预处理操作
// Mat为FastDeploy定义的数据结构
// FDTensor为预处理后的Tensor数据,传给后端进行推理
// im_info为预处理过程保存的数据,在后处理中需要用到
bool Preprocess(Mat* mat, FDTensor* outputs,
std::map<std::string, std::array<float, 2>>* im_info);

// 后端推理结果后处理,输出给用户
// infer_result 为后端推理后的输出Tensor
// result 为模型预测的结果
// im_info 为预处理记录的信息,后处理用于还原box
// conf_threshold 后处理时过滤box的置信度阈值
// nms_iou_threshold 后处理时NMS设定的iou阈值
bool Postprocess(std::vector<FDTensor>& infer_result,
FaceDetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold);

// 查看输入是否为动态维度的 不建议直接使用 不同模型的逻辑可能不一致
bool IsDynamicInput() const { return is_dynamic_input_; }

bool is_dynamic_input_;
};

} // namespace linzaer
} // namespace vision
} // namespace fastdeploy
2 changes: 2 additions & 0 deletions csrcs/fastdeploy/vision/vision_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void BindMeituan(pybind11::module& m);
void BindMegvii(pybind11::module& m);
void BindDeepCam(pybind11::module& m);
void BindRangiLyu(pybind11::module& m);
void BindLinzaer(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m);
#endif
Expand Down Expand Up @@ -69,6 +70,7 @@ void BindVision(pybind11::module& m) {
BindMegvii(m);
BindDeepCam(m);
BindRangiLyu(m);
BindLinzaer(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m);
#endif
Expand Down
Loading

0 comments on commit dca2a97

Please sign in to comment.