当前位置:   article > 正文

【Paddle2ONNX】为Paddle2ONNX适配roll算子

paddle2onnx

1 简介

Roll算子一般被用再Swin结构中,Paddle2ONNX暂时不支持该算子,本教程介绍如何为Paddle2ONNX添加roll算子。

2 实现过程

2.1 Roll算子简介

paddle.roll(x, shifts, axis=None, name=None)

  • x (Tensor)– 输入的 Tensor。
  • shifts (int|list|tuple) - 滚动位移。如果 shifts 是一个元组或者列表,则 axis 必须是相同大小的元组或者列表,输入 Tensor 将依次沿着每个维度滚动相应的数值。
  • axis (int|list|tuple,可选) – 滚动轴。默认值为 None。
  • name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。

沿着指定维度 axis 对输入 x 进行循环滚动,当元素移动到最后位置时,会从第一个位置重新插入。如果 axis 为 None,则输入在被循环滚动之前,会先展平成 1-D Tensor,滚动操作完成后恢复成原来的形状。

2.2 在Paddle2ONNX中实现roll算子

  1. 首先在paddle2onnx/mapper/tensor下新建roll.hroll.cpp,并在roll.h中添加对Roll算子的定义
#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {
class RollMapper : public Mapper {
    public:
    RollMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
    int64_t op_id)
    : Mapper(p, helper, block_id, op_id) {}
    void Opset7();
};
}  // namespace paddle2onnx
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  1. 注册Roll算子
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
REGISTER_MAPPER(roll, RollMapper)
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  1. 添加对axis为None情况下的Roll算子
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);
  std::vector<int64_t> axis;
  GetAttr("axis", &axis);

  std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
  auto result_name = input_info[0].name;
  if (axis.empty())
  {
    int64_t axes = 0;
    result_name = helper_->Flatten(result_name);
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
    helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
    // helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name});
  }
}
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  1. 添加对axis不为None的实现
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);
  std::vector<int64_t> axis;
  GetAttr("axis", &axis);

  std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
  auto result_name = input_info[0].name;
  if (axis.empty())
  {
  } else {
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      int64_t axes = axis[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      if(i+1 == shifts.size()) {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name});
      } else {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      }
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
  }
}
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  1. 合并后核心代码如下:
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
REGISTER_MAPPER(roll, RollMapper)

void RollMapper::Opset7() {
  auto input_info = GetInput("X");
  auto output_info = GetOutput("Out");

  std::vector<int64_t> shifts;
  GetAttr("shifts", &shifts);

  std::vector<int64_t> axis;
  GetAttr("axis", &axis);

  std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
  auto result_name = input_info[0].name;
  if (axis.empty())
  {
    int64_t axes = 0;
    result_name = helper_->Flatten(result_name);
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
    helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
    // helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name});
  } else {
    for(int i = 0;i < shifts.size();i++) {
      auto shift = shifts[i];
      int64_t axes = axis[i];
      auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
      auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
      if(i+1 == shifts.size()) {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name});
      } else {
        temp_node = helper_->MakeNode("Concat", {result_0, result_1});
      }
      AddAttribute(temp_node, "axis", axes);
      result_name = temp_node->output(0);
    }
  }
}
}  // namespace paddle2onnx
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

3 参考资料

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/670469
推荐阅读
相关标签
  

闽ICP备14008679号