Monday 18 February 2019

Face recognition with mxnet, dlib and opencv

    In this post I will show you how to implement an industrial-level, portable face recognition application with a small, reusable example, without relying on any commercial library (the one exception is Qt5, but the Qt5 modules used in this example are available under the LGPL license).

    Before deep learning became the mainstream technology in computer vision, 2D face recognition only worked well under strictly controlled conditions, which made it impractical.

    Thanks to the contributions of open source communities like dlib, opencv and mxnet, high-accuracy 2D face recognition is no longer a difficult problem today.

    Before we start, let us look at an interesting example (video_00).


video_00

    Although different angles and expressions affect the confidence value a lot, most of the time the algorithm is still able to find the most similar face among the 25 faces.

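    The face recognition flow of the code on github consists of 4 critical steps: face detection, face alignment, feature extraction and finding the most similar face in a database.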


pic_00

Detect faces with dlib

std::vector<mmod_rect> face_detector::forward_lazy(const cv::Mat &input)
{
    //make sure the input image has 3 channels
    CV_Assert(input.channels() == 3);

    //Resize the input image to a fixed width.
    //The bigger face_detect_width_ is, the more
    //faces can be detected, but detection consumes
    //more memory and runs slower
    if(input.cols != face_detect_width_){
        //resize_cache_ is reused between calls to reduce
        //the number of memory allocations
        double const ratio = face_detect_width_ / 
                             static_cast<double>(input.cols);
        cv::resize(input, resize_cache_, {}, ratio, ratio);
    }else{
        resize_cache_ = input;
    }

    //1. convert cv::Mat to dlib::matrix
    //2. swap the bgr channels to rgb
    img_.set_size(resize_cache_.rows, resize_cache_.cols);
    dlib::assign_image(img_, dlib::cv_image<bgr_pixel>(resize_cache_));

    //the rectangles are in the coordinates of the resized image
    return net_(img_);
}
 
    dlib's face detector performs very well; you can check the results in their post.

    If you want to know the details, please study the example provided by dlib; if you want to know about more options, please study the excellent post on Learn OpenCV.
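    One detail of forward_lazy worth keeping in mind: because the frame is resized to face_detect_width_ before detection, the returned mmod_rect values are in the coordinates of the resized image. get_aligned_face below works on the same resized img_, so nothing changes there, but if you want to draw the boxes on the original frame you have to divide by the same ratio. The sketch below shows only that arithmetic; Rect, resize_ratio and to_original are hypothetical names, and Rect is a stand-in for dlib::rectangle so the example compiles without dlib.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for dlib::rectangle (left/top/right/bottom),
// used so this sketch compiles without dlib.
struct Rect
{
    long left, top, right, bottom;
};

// The ratio forward_lazy passes to cv::resize when the frame is
// resized to detect_width before detection.
double resize_ratio(int detect_width, int original_width)
{
    assert(original_width > 0);
    return detect_width / static_cast<double>(original_width);
}

// Map a rectangle found in the resized image back to the
// coordinates of the original image by undoing the scaling.
Rect to_original(Rect const &r, double ratio)
{
    assert(ratio > 0.0);
    auto scale = [ratio](long v) {
        return static_cast<long>(std::lround(v / ratio));
    };
    return {scale(r.left), scale(r.top), scale(r.right), scale(r.bottom)};
}
```

    For example, with a 1920 pixel wide frame and a detection width of 640, the ratio is 1/3, so a box at left = 100 in the resized image starts at left = 300 in the original frame.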

Perform face alignment with dlib

    We can treat face alignment as a data normalization technique for face recognition: usually you align the faces before training your model and align them again at prediction time, which helps you obtain higher accuracy.

    With dlib, face alignment becomes very simple; it takes just a few lines of code.

//rect contains the ROI of the face
dlib::matrix<rgb_pixel> face_detector::
get_aligned_face(const mmod_rect &rect)
{
    //pose_model_ is a dlib::shape_predictor,
    //it returns the landmarks of the face
    auto shape = pose_model_(img_, rect);
    dlib::matrix<rgb_pixel> face_chip;
    //face_aligned_size_ is the side length of the square chip,
    //0.25 is the padding added around the face
    auto const details = 
          get_face_chip_details(shape, face_aligned_size_, 0.25);
    //extract the aligned face from the image
    extract_image_chip(img_, details, face_chip);
    return face_chip;
}
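    Roughly speaking, get_face_chip_details maps the detected landmarks onto a canonical face template and describes the rotated, scaled crop to take, and extract_image_chip then resamples that crop from the image. The core of such a mapping is a 2D similarity transform (rotation, uniform scale and translation) fitted by least squares. The standalone sketch below shows that fit between two point sets; it illustrates the general technique and is not dlib's implementation, and Point2, Similarity and fit_similarity are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Point2
{
    double x, y;
};

// x' = a*x - b*y + tx,  y' = b*x + a*y + ty
// (a, b) encode scale * (cos(theta), sin(theta)).
struct Similarity
{
    double a, b, tx, ty;

    Point2 apply(Point2 const &p) const
    {
        return {a * p.x - b * p.y + tx, b * p.x + a * p.y + ty};
    }
};

// Least-squares similarity transform that maps src[i] onto dst[i].
Similarity fit_similarity(std::vector<Point2> const &src,
                          std::vector<Point2> const &dst)
{
    assert(src.size() == dst.size() && src.size() >= 2);
    double const n = static_cast<double>(src.size());

    // Centroids of both point sets.
    Point2 ms{0, 0}, md{0, 0};
    for(std::size_t i = 0; i != src.size(); ++i){
        ms.x += src[i].x / n; ms.y += src[i].y / n;
        md.x += dst[i].x / n; md.y += dst[i].y / n;
    }

    // Closed-form solution for (a, b) on centred coordinates.
    double num_a = 0, num_b = 0, den = 0;
    for(std::size_t i = 0; i != src.size(); ++i){
        double const px = src[i].x - ms.x, py = src[i].y - ms.y;
        double const qx = dst[i].x - md.x, qy = dst[i].y - md.y;
        num_a += px * qx + py * qy;
        num_b += px * qy - py * qx;
        den   += px * px + py * py;
    }
    assert(den > 0.0); // source points must not all coincide

    Similarity t{num_a / den, num_b / den, 0, 0};
    // The translation maps the source centroid onto the destination centroid.
    t.tx = md.x - (t.a * ms.x - t.b * ms.y);
    t.ty = md.y - (t.b * ms.x + t.a * ms.y);
    return t;
}
```

    With landmarks as src and template positions as dst, applying the fitted transform to every output pixel (in the inverse direction) is what produces an upright, consistently scaled face chip.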

Extract face features with mxnet

    This section needs to load a model with mxnet. Unlike dlib or opencv, the C++ API of mxnet is more complicated; if you do not know how to load an mxnet model yet, I recommend you study this post.

    This section is the most complicated part because it covers three main points:

1.  Extract the features of faces.
2.  Perform batch processing.
3.  Convert the aligned face from dlib (stored as matrix<rgb_pixel>) to a contiguous float array in the format expected by the mxnet model.


A. Load the model with a variable batch size


    In order to load a model that supports a variable batch size, all we need to do is add one more entry, "data1", to the argument list.


std::unique_ptr<Executor> create_executor(const std::string &model_params,
                                          const std::string &model_symbols,
                                          const Context &context,
                                          const Shape &input_shape)
{    
    Symbol net;
    std::map<std::string, NDArray> args, auxs;
    load_check_point(model_params, model_symbols, &net, 
                     &args, &auxs, context);

    //if the key "data" throws an exception, try another key, like "data0"
    args["data"] = NDArray(input_shape, context, false);
    //we only need to add the new key if the batch size is larger than 1
    if(input_shape[0] > 1){
        //all we need is the new key "data1", which tells the model
        //how many faces of the batch need to be processed
        args["data1"] = NDArray(Shape(1), context, false);
    }

    std::unique_ptr<Executor> executor;
    executor.reset(net.SimpleBind(context, 
                                  args, 
                                  std::map<std::string, NDArray>(),
                                  std::map<std::string, OpReqType>(), 
                                  auxs));

    return executor;
}

B. Convert the aligned faces to a float array

    Unlike the yolo v3 example, the input of the deepsight model needs more preprocessing before you can feed the aligned face into it. Instead of arranging the pixels in RGB order, you need to split each channel of the face into a separate "page". Simply put, instead of arranging the pixels as

R1G1B1R2G2B2...RnGnBn

we should arrange the pixels as

R1R2...RnG1G2...GnB1B2...Bn


//using dlib_const_images_ptr = std::vector<matrix<rgb_pixel> const*>;
//image_vector_ is a std::vector<float>, resized in the constructor:
//image_vector_.resize(params_->shape_.Size());
//params_->shape_.Size() returns the total number
//of elements in the tensor
void face_key_extractor::
dlib_matrix_to_float_array(dlib_const_images_ptr const &rgb_image)
{
    size_t index = 0;
    for(size_t i = 0; i != rgb_image.size(); ++i){
        auto const &img = *rgb_image[i];
        //write one "page" per channel: 0 = red, 1 = green, 2 = blue
        for(size_t ch = 0; ch != 3; ++ch){
            for(long row = 0; row != img.nr(); ++row){
                for(long col = 0; col != img.nc(); ++col){
                    auto const &pix = img(row, col);
                    image_vector_[index++] = ch == 0 ? pix.red :
                                             ch == 1 ? pix.green :
                                                       pix.blue;
                }
            }
        }
    }
}
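    The same reordering, written against a plain interleaved 8-bit buffer instead of dlib::matrix<rgb_pixel>, makes the index arithmetic explicit. This is a sketch: interleaved_to_planar and nchw_offset are hypothetical helpers, and appending to the output vector mimics how several faces are packed one after another into one batch (NCHW layout: batch, channel, row, column).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Convert one interleaved RGB image (R1G1B1R2G2B2...) into planar order
// (R1R2...RnG1G2...GnB1B2...Bn) and append the floats to `out`.
// Appending lets several faces be packed one after another into a batch.
void interleaved_to_planar(std::vector<std::uint8_t> const &rgb,
                           std::size_t width, std::size_t height,
                           std::vector<float> &out)
{
    std::size_t const pixels = width * height;
    assert(rgb.size() == pixels * 3);
    for(std::size_t ch = 0; ch != 3; ++ch){        // 0 = R, 1 = G, 2 = B
        for(std::size_t p = 0; p != pixels; ++p){
            out.push_back(static_cast<float>(rgb[p * 3 + ch]));
        }
    }
}

// Offset of element (n, c, y, x) in an NCHW buffer with 3 channels.
std::size_t nchw_offset(std::size_t n, std::size_t c, std::size_t y,
                        std::size_t x, std::size_t height, std::size_t width)
{
    return ((n * 3 + c) * height + y) * width + x;
}
```

    For example, a 2x1 image with the pixels (1, 2, 3) and (4, 5, 6) becomes 1, 4, 2, 5, 3, 6.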


C. Forward aligned faces with a variable batch size

    There are two things you must know before we dive into the source code:

1. To avoid memory reallocation, we allocate memory once for the largest possible batch size and reuse that same memory when the batch size is smaller.
2. The batch size of the float array fed to the model must always equal the largest possible batch size, even when fewer faces are processed.
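    The batching arithmetic behind the forward function below can be sketched on its own, without mxnet. forward_count uses integer ceiling division instead of std::ceil on floats, batch_sizes lists how many faces each forward pass would actually carry (the value passed as "data1" in the post's model), and buffer_elements is the size of the reused input buffer. All three function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Integer ceiling division: number of forward passes needed for
// `total` faces when at most `max_batch` faces fit into one pass.
std::size_t forward_count(std::size_t total, std::size_t max_batch)
{
    assert(max_batch > 0);
    return (total + max_batch - 1) / max_batch;
}

// Number of faces carried by each pass: every pass is full
// except possibly the last one.
std::vector<std::size_t> batch_sizes(std::size_t total, std::size_t max_batch)
{
    assert(max_batch > 0);
    std::vector<std::size_t> sizes;
    for(std::size_t done = 0; done < total; done += max_batch){
        std::size_t const left = total - done;
        sizes.push_back(left < max_batch ? left : max_batch);
    }
    return sizes;
}

// Floats in the input buffer, allocated once for the largest batch:
// N * C * H * W with N = max_batch.
std::size_t buffer_elements(std::size_t max_batch, std::size_t channels,
                            std::size_t height, std::size_t width)
{
    return max_batch * channels * height * width;
}
```

    For example, 7 faces with a maximum batch size of 3 need 3 passes carrying 3, 3 and 1 faces, while the input buffer keeps room for 3 faces in every pass.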


//input contains all of the aligned faces detected from the image
std::vector<face_key> face_key_extractor::
forward(const std::vector<dlib::matrix<dlib::rgb_pixel> > &input)
{
    if(input.empty()){
        return {};
    }

    //The number of faces may not be divisible by the batch size,
    //so round the number of forward passes up to make sure
    //the features of every face are extracted
    size_t const max_batch = params_->shape_[0];
    size_t const forward_count = (input.size() + max_batch - 1) / max_batch;
    std::vector<face_key> result;
    for(size_t i = 0, index = 0; i != forward_count; ++i){
        //collect at most max_batch faces for this pass
        dlib_const_images_ptr faces;
        for(size_t j = 0; j != max_batch && index < input.size(); ++j){
            faces.emplace_back(&input[index++]);
        }
        dlib_matrix_to_float_array(faces);
        auto features = forward(image_vector_, faces.size());
        std::move(std::begin(features), std::end(features), 
                  std::back_inserter(result));
    }

    return result;
}

D. Extract features of faces

std::vector<face_key> face_key_extractor::
forward(const std::vector<float> &input, size_t batch_size)
{
    executor_->arg_dict()["data"].SyncCopyFromCPU(input.data(), 
                                                  input.size());
    //data1 tells the executor how many face(s) to process
    executor_->arg_dict()["data1"] = static_cast<mx_float>(batch_size);
    executor_->Forward(false);
    std::vector<face_key> result;
    if(!executor_->outputs.empty()){
        //shape of features is [batch_size, 512]
        auto features = executor_->outputs[0].Copy(Context(kCPU, 0));
        Shape const shape(1, step_per_feature);
        //Copy is asynchronous, wait until the data is ready
        features.WaitToRead();
        //split the features into one face_key per face
        for(size_t i = 0; i != batch_size; ++i){
            //step_per_feature is 512; the memory of an NDArray
            //is contiguous, which makes slicing easy
            NDArray feature(features.GetData() + i * step_per_feature, 
                            shape, Context(kCPU, 0));
            result.emplace_back(std::move(feature));
        }
    }

    return result;
}
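    Because the output NDArray is one contiguous, row-major block of batch_size x 512 floats, row i starts at offset i * 512. The plain C++ sketch below performs the same slicing but copies each row into a std::vector<float> instead of wrapping it in an NDArray; split_features is a hypothetical helper and dim plays the role of step_per_feature.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Split a contiguous [batch_size x dim] feature buffer into one vector
// per face. Only the first batch_size rows are read, even if the buffer
// was allocated for a larger maximum batch.
std::vector<std::vector<float>> split_features(float const *data,
                                               std::size_t batch_size,
                                               std::size_t dim)
{
    assert(data != nullptr || batch_size == 0);
    std::vector<std::vector<float>> result;
    result.reserve(batch_size);
    for(std::size_t i = 0; i != batch_size; ++i){
        float const *row = data + i * dim;   // row i starts at i * dim
        result.emplace_back(row, row + dim);
    }
    return result;
}
```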
 
 

Find the most similar face in the database

    In this small example I use cosine similarity to compare two faces, which is quite easy with the help of opencv.

A. Compare similarity


double face_key::similarity(const face_key &input) const
{
    CV_Assert(key_.GetData() != nullptr && 
              input.key_.GetData() != nullptr);

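    //wrap the 512 floats of each feature in a cv::Mat_ header,
    //no data is copied
    cv::Mat_<float> const key1(1, 512, 
                               const_cast<float*>(input.key_.GetData()), 0);
    cv::Mat_<float> const key2(1, 512, 
                               const_cast<float*>(key_.GetData()), 0);
    //cosine similarity = dot(key1, key2) / (|key1| * |key2|)
    auto const denominator = std::sqrt(key1.dot(key1) * key2.dot(key2));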
    if(denominator != 0.0){
        return key2.dot(key1) / denominator;
    }

    return 0;
}
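    If you would rather not depend on cv::Mat_ for this step, the same cosine similarity can be computed with the standard library alone. A minimal sketch, assuming the features are stored as std::vector<float> (cosine_similarity is a hypothetical name):

```cpp
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

// Cosine similarity: dot(a, b) / (|a| * |b|), in the range [-1, 1].
// Returns 0 when either vector has zero length, like face_key::similarity.
double cosine_similarity(std::vector<float> const &a,
                         std::vector<float> const &b)
{
    assert(a.size() == b.size());
    double const dot = std::inner_product(a.begin(), a.end(), b.begin(), 0.0);
    double const aa  = std::inner_product(a.begin(), a.end(), a.begin(), 0.0);
    double const bb  = std::inner_product(b.begin(), b.end(), b.begin(), 0.0);
    double const denominator = std::sqrt(aa * bb);
    return denominator != 0.0 ? dot / denominator : 0.0;
}
```

    If the features are L2-normalized once when they are stored, the denominator becomes 1 and the similarity reduces to a plain dot product, which avoids recomputing the norms for every comparison.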

B. Find the most similar face

    Finding the most similar face is really easy: all we need to do is compare the input with the features stored in the array one by one and return the one with the highest confidence.


//for simplicity, the structs are shown here on their own;
//the function below refers to id_info as face_reg_db::id_info
struct id_info
{
    double confident_ = -1.0;
    std::string id_;
};

struct face_info
{
    face_key key_;
    std::string id_;
};

face_reg_db::id_info face_reg_db::
find_most_similar_face(const face_key &input) const
{
    id_info result;
    //type of face_keys_ is std::vector<face_info>
    for(size_t i = 0; i != face_keys_.size(); ++i){
        auto const confident = 
             face_keys_[i].key_.similarity(input);
        if(confident > result.confident_){
            result.confident_ = confident;
            result.id_ = face_keys_[i].id_;
        }
    }

    return result;
}
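    find_most_similar_face always returns the closest id, even for a person who is not in the database. A common addition is to reject the best match when its score falls below a threshold. The sketch below combines the linear scan with such a check; gallery_entry, match, find_best and the threshold parameter are hypothetical, and a suitable threshold has to be tuned on your own data.

```cpp
#include <cassert>
#include <cmath>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical gallery entry, mirroring face_info with a plain float vector.
struct gallery_entry
{
    std::string id;
    std::vector<float> key;
};

struct match
{
    std::string id;            // empty when the best score is below threshold
    double similarity = -1.0;  // best cosine similarity found
};

double cosine(std::vector<float> const &a, std::vector<float> const &b)
{
    assert(a.size() == b.size());
    double const dot = std::inner_product(a.begin(), a.end(), b.begin(), 0.0);
    double const den = std::sqrt(
        std::inner_product(a.begin(), a.end(), a.begin(), 0.0) *
        std::inner_product(b.begin(), b.end(), b.begin(), 0.0));
    return den != 0.0 ? dot / den : 0.0;
}

// Linear scan: keep the entry with the highest similarity, then reject
// it when the score is below `threshold` (an application-specific value).
match find_best(std::vector<gallery_entry> const &gallery,
                std::vector<float> const &query, double threshold)
{
    match best;
    for(auto const &entry : gallery){
        double const s = cosine(entry.key, query);
        if(s > best.similarity){
            best.similarity = s;
            best.id = entry.id;
        }
    }
    if(best.similarity < threshold){
        best.id.clear();
    }
    return best;
}
```

    The scan is linear in the size of the database, which is fine for a few thousand faces; the returned similarity is kept even when the id is cleared, so the caller can still log how close the rejected match was.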

Summary

    In today's post, I showed you the most critical parts of face recognition with opencv, dlib and mxnet. I believe this is a great starting point if you want to build a high-quality face recognition app in C++.

    Real-world applications are much more complicated than this small example, since they always need to support more features and must be efficient, but no matter how complex they are, the main flow of 2D face recognition is almost the same as the one shown in this post.

Comments:

  1. Hi,

    We have a 1080p RTSP video feed and an NVIDIA 1060ti card. It looks a bit slow and can't keep up with the video feed even when we reduce the frame rate to 14 fps.

    Do we need another step to speed up the dlib CNN detector, such as sending two frames at once?

    What is your suggestion for fast and accurate detection on a 1080p RTSP video stream?

  2. I would suggest measuring the bottleneck first, before you perform any optimization.
