当前位置:   article > 正文

Caffe 损失层中loss_weight 如何存储?_caffe loss_weight

caffe loss_weight

一个网络中如果存在多个损失层的话,需要给每个损失层加上loss_weight参数,不加的话默认为1.0
但是loss_weight如何存储的呢?

这里我是从ContrastiveLossLayer::Backward_cpu中发现的:

const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] /
      static_cast<Dtype>(bottom[i]->num());
  • 1
  • 2
  • 3

其中top[0]->cpu_diff()[0]保存的即为该层的loss_weight

训练时函数调用如下:

这里写图片描述

在所有层的父类layer.hpp中会执行下列操作:

void SetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    InitMutex();
    CheckBlobCounts(bottom, top);
    LayerSetUp(bottom, top);
    Reshape(bottom, top);
    SetLossWeights(top);
  }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

先执行完LayerSetUp和Reshape的初始化操作,调用了SetLossWeights,其中caffe_set(count, loss_weight, loss_multiplier);将loss_weight赋值给top[0]->cpu_diff()。

/**
  * Called by SetUp to initialize the weights associated with any top blobs in
  * the loss function. Store non-zero loss weights in the diff blob.
  */
 inline void SetLossWeights(const vector<Blob<Dtype>*>& top) {
   const int num_loss_weights = layer_param_.loss_weight_size();
   if (num_loss_weights) {
     CHECK_EQ(top.size(), num_loss_weights) << "loss_weight must be "
         "unspecified or specified once per top blob.";
     for (int top_id = 0; top_id < top.size(); ++top_id) {
       const Dtype loss_weight = layer_param_.loss_weight(top_id);
       if (loss_weight == Dtype(0)) { continue; }
       this->set_loss(top_id, loss_weight);
       const int count = top[top_id]->count();
       Dtype* loss_multiplier = top[top_id]->mutable_cpu_diff();
       caffe_set(count, loss_weight, loss_multiplier);
     }
   }
 }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

从const Dtype loss_weight = layer_param_.loss_weight(top_id);可以看到loss_wight可以直接从layer_param_中获取

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/133285
推荐阅读
相关标签
  

闽ICP备14008679号