当前位置:   article > 正文

从零构建深度学习推理框架-3 手写算子relu_relu算子

relu算子

Relu介绍:

f(x) = \left\{\begin{matrix}x , x>thresh & & \\0,x<thresh & & \end{matrix}\right.

 relu是一个非线性激活函数,可以避免梯度消失,过拟合等情况。我们一般将thresh设为0。

operator类:

  1. #ifndef KUIPER_COURSE_INCLUDE_OPS_OP_HPP_
  2. #define KUIPER_COURSE_INCLUDE_OPS_OP_HPP_
  3. namespace kuiper_infer {
  4. enum class OpType {
  5. kOperatorUnknown = -1,
  6. kOperatorRelu = 0,
  7. };
  8. class Operator {
  9. public:
  10. OpType op_type_ = OpType::kOperatorUnknown; //不是一个具体节点 制定为unknown
  11. virtual ~Operator() = default; //
  12. explicit Operator(OpType op_type);
  13. };

这里的  kOperatorUnknown = -1 , kOperatorRelu = 0分别是他们的代号

operator是一个父类,我们的relu就要继承于这个父类

  1. class ReluOperator : public Operator {
  2. public:
  3. ~ReluOperator() override = default;
  4. explicit ReluOperator(float thresh);
  5. void set_thresh(float thresh);
  6. float get_thresh() const;
  7. private:
  8. // 需要传递到reluLayer中,怎么传递?
  9. float thresh_ = 0.f; // 用于过滤tensor<float>值当中大于thresh的部分
  10. // relu存的变量只有thresh
  11. // stride padding kernel_size 这些是到时候convOperator需要的
  12. // operator起到了属性存储、变量的作用
  13. // operator所有子类不负责具体运算
  14. // 具体运算由另外一个类Layer类负责
  15. // y =x , if x >=0 y = 0 if x < 0
  16. };

 operator起到了属性存储、变量的作用
 operator所有子类不负责具体运算
 具体运算由另外一个类Layer类负责

layer类:

  1. class Layer {
  2. public:
  3. explicit Layer(const std::string &layer_name);
  4. virtual void Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,
  5. std::vector<std::shared_ptr<Tensor<float>>> &outputs);
  6. // reluLayer中 inputs 等于 x , outputs 等于 y= x,if x>0
  7. // 计算得到的结果放在y当中,x是输入,放在inputs中
  8. virtual ~Layer() = default;
  9. private:
  10. std::string layer_name_; //relu layer "relu"
  11. };

父类只保留了一个layer_name属性和两个方法。

具体的在relu_layer这个class中

  1. class ReluLayer : public Layer {
  2. public:
  3. ~ReluLayer() override = default;
  4. // 通过这里,把relu_op中的thresh告知给relu layer, 因为计算的时候要用到
  5. explicit ReluLayer(const std::shared_ptr<Operator> &op);
  6. // 执行relu 操作的具体函数Forwards
  7. void Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,
  8. std::vector<std::shared_ptr<Tensor<float>>> &outputs) override;
  9. // 下节的内容,不用管
  10. static std::shared_ptr<Layer> CreateInstance(const std::shared_ptr<Operator> &op);
  11. private:
  12. std::unique_ptr<ReluOperator> op_;
  13. };

具体的方法实现:

  1. ReluLayer::ReluLayer(const std::shared_ptr<Operator> &op) : Layer("Relu") {
  2. CHECK(op->op_type_ == OpType::kOperatorRelu) << "Operator has a wrong type: " << int(op->op_type_);
  3. // dynamic_cast是什么意思? 就是判断一下op指针是不是指向一个relu_op类的指针
  4. // 这边的op不是ReluOperator类型的指针,就报错
  5. // 我们这里只接受ReluOperator类型的指针
  6. // 父类指针必须指向子类ReluOperator类型的指针
  7. // 为什么不讲构造函数设置为const std::shared_ptr<ReluOperator> &op?
  8. // 为了接口统一,具体下节会说到
  9. ReluOperator *relu_op = dynamic_cast<ReluOperator *>(op.get());
  10. CHECK(relu_op != nullptr) << "Relu operator is empty";
  11. // 一个op实例和一个layer 一一对应 这里relu op对一个relu layer
  12. // 对应关系
  13. this->op_ = std::make_unique<ReluOperator>(relu_op->get_thresh());
  14. }
  15. void ReluLayer::Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,
  16. std::vector<std::shared_ptr<Tensor<float>>> &outputs) {
  17. // relu 操作在哪里,这里!
  18. // 我需要该节点信息的时候 直接这么做
  19. // 实行了属性存储和运算过程的分离!!!!!!!!!!!!!!!!!!!!!!!!
  20. //x就是inputs y = outputs
  21. CHECK(this->op_ != nullptr);
  22. CHECK(this->op_->op_type_ == OpType::kOperatorRelu);
  23. const uint32_t batch_size = inputs.size(); //一批x,放在vec当中,理解为batchsize数量的tensor,需要进行relu操作
  24. for (int i = 0; i < batch_size; ++i) {
  25. CHECK(!inputs.at(i)->empty());
  26. const std::shared_ptr<Tensor<float>> &input_data = inputs.at(i); //取出批次当中的一个张量
  27. //对张量中的每一个元素进行运算,进行relu运算
  28. input_data->data().transform([&](float value) {
  29. // 对张良中的没一个元素进行运算
  30. // 从operator中得到存储的属性
  31. float thresh = op_->get_thresh();
  32. //x >= thresh
  33. if (value >= thresh) {
  34. return value; // return x
  35. } else {
  36. // x<= thresh return 0.f;
  37. return 0.f;
  38. }
  39. });
  40. // 把结果y放在outputs中
  41. outputs.push_back(input_data);
  42. }
  43. }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/79211
推荐阅读
相关标签
  

闽ICP备14008679号