当前位置:   article > 正文

ncnn源码解析(五):执行器Extractor_ncnn::extractor extract

ncnn::extractor extract

前面大致总结了一下ncnn模型载入的流程,模型载入之后,就是新建一个Extractor,然后设置输入,获取输出:

  1. ncnn::Extractor ex = net.create_extractor();
  2. ex.set_num_threads(4);
  3. ex.input("data", in);
  4. ncnn::Mat out;
  5. ex.extract("detection_out", out);

        现在可以看一下Extractor的定义了:

  1. class Extractor
  2. {
  3. public:
  4. // enable light mode
  5. // intermediate blob will be recycled when enabled
  6. // enabled by default
  7. // 设置light模式
  8. void set_light_mode(bool enable);
  9. // set thread count for this extractor
  10. // this will overwrite the global setting
  11. // default count is system depended
  12. // 设置线程数
  13. void set_num_threads(int num_threads);
  14. // set blob memory allocator
  15. // 设置blob的内存分配器
  16. void set_blob_allocator(Allocator* allocator);
  17. // set workspace memory allocator
  18. // 设置工作空间的内存分配器
  19. void set_workspace_allocator(Allocator* allocator);
  20. #if NCNN_STRING
  21. // set input by blob name
  22. // return 0 if success
  23. // 设置网络输入:字符串layer名
  24. int input(const char* blob_name, const Mat& in);
  25. // get result by blob name
  26. // return 0 if success
  27. // 设置提取器的输入:得到对应输出
  28. int extract(const char* blob_name, Mat& feat);
  29. #endif // NCNN_STRING
  30. // set input by blob index
  31. // return 0 if success
  32. // 设置int类型blob索引及输入
  33. int input(int blob_index, const Mat& in);
  34. // get result by blob index
  35. // return 0 if success
  36. // 设置int类型blob索引及输出
  37. int extract(int blob_index, Mat& feat);
  38. protected:
  39. // 对外提供create_extractor接口
  40. friend Extractor Net::create_extractor() const;
  41. Extractor(const Net* net, int blob_count);
  42. private:
  43. // 网络
  44. const Net* net;
  45. // blob的mat
  46. std::vector<Mat> blob_mats;
  47. // 选项
  48. Option opt;
  49. };

        将vulkan部分代码剔除掉,不难发现,Extractor里面就这么多内容,除了设置option的接口之外,就只剩下我们需要使用的几个接口函数了:

  1. // 创建Extractor
  2. Extractor Net::create_extractor() const
  3. {
  4. return Extractor(this, blobs.size());
  5. }

        内部调用了接口为,就是将blob_mat数组resize到网络的blob数目大小,然后设置了一下选项:

  1. // 执行器
  2. Extractor::Extractor(const Net* _net, int blob_count) : net(_net)
  3. {
  4. blob_mats.resize(blob_count);
  5. opt = net->opt;
  6. }

        然后就是input接口:

  1. // 设置输入
  2. int Extractor::input(const char* blob_name, const Mat& in)
  3. {
  4. // 获取输入模块对应index
  5. int blob_index = net->find_blob_index_by_name(blob_name);
  6. if (blob_index == -1)
  7. return -1;
  8. // 调用直接用index的设置input方法
  9. return input(blob_index, in);
  10. }

        内部调用的接口为:

  1. // 输入为index的输入接口
  2. int Extractor::input(int blob_index, const Mat& in)
  3. {
  4. if (blob_index < 0 || blob_index >= (int)blob_mats.size())
  5. return -1;
  6. // 设置blob_index对应Mat
  7. blob_mats[blob_index] = in;
  8. return 0;
  9. }

        这里就是设置输入对应blob值。

        最后一个接口就是extract接口:

  1. // 将输入string类型name转换成对应的索引
  2. int Extractor::extract(const char* blob_name, VkMat& feat, VkCompute& cmd)
  3. {
  4. int blob_index = net->find_blob_index_by_name(blob_name);
  5. if (blob_index == -1)
  6. return -1;
  7. return extract(blob_index, feat, cmd);
  8. }

        这里调用的接口为:

  1. // 提取特征
  2. int Extractor::extract(int blob_index, Mat& feat)
  3. {
  4. if (blob_index < 0 || blob_index >= (int)blob_mats.size())
  5. return -1;
  6. int ret = 0;
  7. // 如果输出blob为空
  8. if (blob_mats[blob_index].dims == 0)
  9. {
  10. // 查找输出blob对应的生产者
  11. int layer_index = net->blobs[blob_index].producer;
  12. // 前向推理
  13. ret = net->forward_layer(layer_index, blob_mats, opt);
  14. }
  15. // 输出特征
  16. feat = blob_mats[blob_index];
  17. if (opt.use_packing_layout)
  18. {
  19. // 对特征进行unpack
  20. Mat bottom_blob_unpacked;
  21. convert_packing(feat, bottom_blob_unpacked, 1, opt);
  22. feat = bottom_blob_unpacked;
  23. }
  24. return ret;
  25. }

        这里就是调用各层前向推理forward_layer方法来进行推理的,这个对应于特定层的推理过程,后面总结各个层的时候再说。

参考资料:

[1] https://github.com/Tencent/ncnn

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/703326
推荐阅读
相关标签
  

闽ICP备14008679号