当前位置:   article > 正文

Caffe中增加新的layer以及Caffe中triplet loss layer的实现_caffe 构建新层 tripletloss

caffe 构建新层 tripletloss

关于Tripletloss的原理,目标函数和梯度推导在上一篇博客中已经讲过了,具体见:Tripletloss原理以及梯度推导,这篇博文主要是讲caffe下实现Tripletloss,编程菜鸟,如果有写的不优化的地方,欢迎指出。

尊重原创,转载请注明:http://blog.csdn.net/tangwei2014

1.如何在caffe中增加新的layer

新版的caffe中增加新的layer,变得轻松多了,概括说来,分四步:

1)在./src/caffe/proto/caffe.proto 中增加对应layer的paramter message;

2)在./include/caffe/***layers.hpp中增加该layer的类的声明,***表示有common_layers.hpp,data_layers.hpp, neuron_layers.hpp, vision_layers.hpp 和loss_layers.hpp等;

3)在./src/caffe/layers/目录下新建.cpp和.cu文件,进行类实现。

4)在./src/caffe/gtest/中增加layer的测试代码,对所写的layer前传和反传进行测试,测试还包括速度。

最后一步很多人省了,或者没意识到,但是为保证代码正确,建议还是严格进行测试,磨刀不误砍柴功。

2.caffe中实现Triplettloss layer

1.caffe.proto中增加Triplettloss layer的定义

首先在message LayerParameter中追加 optional TripletLossParameter Triplet_loss_param = 138; 其中138是我目前LayerParameter message中现有元素的个数,具体是多少,可以看LayerParameter message上面注释中的:

//LayerParameter next available layer-specific ID: 134 (last added:reshape_param)

然后增加Message:

  1. message TripletLossParameter {
  2. // margin for dissimilar pair
  3. optional float margin = 1 [default = 1.0];
  4. }

其中 margin就是定义Tripletloss原理以及梯度推导所讲的alpha。

2.在./include/caffe/loss_layers.hpp中增加Tripletloss layer的类的声明

具体解释见注释,主要的是定义了一些变量,用来在前传中存储中间计算结果,以便在反传的时候避免重复计算。

  1. /**
  2. * @brief Computes the Tripletloss
  3. */
  4. template <typename Dtype>
  5. class TripletLossLayer : publicLossLayer<Dtype> {
  6. public:
  7. explicit TripletLossLayer(const LayerParameter& param)
  8. : LossLayer<Dtype>(param){}
  9. virtual void LayerSetUp(const vector<Blob<Dtype>*>&bottom,
  10. constvector<Blob<Dtype>*>& top);
  11. virtual inline int ExactNumBottomBlobs() const { return 4; }
  12. virtual inline const char* type() const { return "TripletLoss";}
  13. /**
  14. * Unlike most loss layers, in the TripletLossLayer we can backpropagate
  15. * to the first three inputs.
  16. */
  17. virtual inline bool AllowForceBackward(const int bottom_index) const {
  18. return bottom_index != 3;
  19. }
  20. protected:
  21. virtual void Forward_cpu(const vector<Blob<Dtype>*>&bottom,
  22. constvector<Blob<Dtype>*>& top);
  23. virtual void Forward_gpu(const vector<Blob<Dtype>*>&bottom,
  24. constvector<Blob<Dtype>*>& top);
  25. virtual void Backward_cpu(const vector<Blob<Dtype>*>&top,
  26. const vector<bool>&propagate_down, const vector<Blob<Dtype>*>& bottom);
  27. virtual void Backward_gpu(const vector<Blob<Dtype>*>&top,
  28. const vector<bool>&propagate_down, const vector<Blob<Dtype>*>& bottom);
  29. Blob<Dtype> diff_ap_; //cached for backward pass
  30. Blob<Dtype> diff_an_; //cached for backward pass
  31. Blob<Dtype> diff_pn_; //cached for backward pass
  32. Blob<Dtype> diff_sq_ap_; //cached for backward pass
  33. Blob<Dtype> diff_sq_an_; //tmp storage for gpu forward pass
  34. Blob<Dtype> dist_sq_ap_; //cached for backward pass
  35. Blob<Dtype> dist_sq_an_; //cached for backward pass
  36. Blob<Dtype> summer_vec_; //tmp storage for gpu forward pass
  37. Blob<Dtype> dist_
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号