赞
踩
Roll 算子一般被用在 Swin Transformer 结构中。Paddle2ONNX 暂时不支持该算子,本教程介绍如何为 Paddle2ONNX 添加 roll 算子。
paddle.roll(x, shifts, axis=None, name=None)
- x (Tensor) - 输入的 Tensor。
- shifts (int|list|tuple) - 滚动位移。如果 shifts 是一个元组或者列表,则 axis 必须是相同大小的元组或者列表,输入 Tensor 将依次沿着每个维度滚动相应的数值。
- axis (int|list|tuple,可选) - 滚动轴。默认值为 None。
- name (str,可选) - 具体用法请参见 Name,一般无需设置,默认值为 None。
沿着指定维度 axis 对输入 x 进行循环滚动,当元素移动到最后位置时,会从第一个位置重新插入。如果 axis 为 None,则输入在被循环滚动之前,会先展平成 1-D Tensor,滚动操作完成后恢复成原来的形状。
#pragma once
#include <string>
#include <vector>
#include "paddle2onnx/mapper/mapper.h"
namespace paddle2onnx {
// Mapper that converts Paddle's `roll` operator into ONNX nodes.
// Only a single conversion entry point, Opset7(), is declared.
class RollMapper : public Mapper {
public:
// Forwards all arguments unchanged to the Mapper base-class constructor.
RollMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
// Emits the ONNX subgraph for this roll op. Named for ONNX opset 7
// (presumably the minimum opset this mapper targets -- confirm against the
// Mapper base class).
void Opset7();
};
} // namespace paddle2onnx
#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"
namespace paddle2onnx {
// Registers RollMapper as the converter for the Paddle op named `roll`.
// NOTE(review): this presumably keys the registry by the first macro
// argument; confirm against REGISTER_MAPPER's definition.
REGISTER_MAPPER(roll, RollMapper)
}  // namespace paddle2onnx
#include <limits> #include "paddle2onnx/mapper/tensor/roll.h" namespace paddle2onnx { void RollMapper::Opset7() { auto input_info = GetInput("X"); auto output_info = GetOutput("Out"); std::vector<int64_t> shifts; GetAttr("shifts", &shifts); std::vector<int64_t> axis; GetAttr("axis", &axis); std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr; auto result_name = input_info[0].name; if (axis.empty()) { int64_t axes = 0; result_name = helper_->Flatten(result_name); for(int i = 0;i < shifts.size();i++) { auto shift = shifts[i]; auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()}); auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift}); temp_node = helper_->MakeNode("Concat", {result_0, result_1}); AddAttribute(temp_node, "axis", axes); result_name = temp_node->output(0); } helper_->Reshape(result_name, output_info[0].name, input_info[0].shape); // helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name}); } } }
#include <limits> #include "paddle2onnx/mapper/tensor/roll.h" namespace paddle2onnx { void RollMapper::Opset7() { auto input_info = GetInput("X"); auto output_info = GetOutput("Out"); std::vector<int64_t> shifts; GetAttr("shifts", &shifts); std::vector<int64_t> axis; GetAttr("axis", &axis); std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr; auto result_name = input_info[0].name; if (axis.empty()) { } else { for(int i = 0;i < shifts.size();i++) { auto shift = shifts[i]; int64_t axes = axis[i]; auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()}); auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift}); if(i+1 == shifts.size()) { temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name}); } else { temp_node = helper_->MakeNode("Concat", {result_0, result_1}); } AddAttribute(temp_node, "axis", axes); result_name = temp_node->output(0); } } } }
#include <limits> #include "paddle2onnx/mapper/tensor/roll.h" namespace paddle2onnx { REGISTER_MAPPER(roll, RollMapper) void RollMapper::Opset7() { auto input_info = GetInput("X"); auto output_info = GetOutput("Out"); std::vector<int64_t> shifts; GetAttr("shifts", &shifts); std::vector<int64_t> axis; GetAttr("axis", &axis); std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr; auto result_name = input_info[0].name; if (axis.empty()) { int64_t axes = 0; result_name = helper_->Flatten(result_name); for(int i = 0;i < shifts.size();i++) { auto shift = shifts[i]; auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()}); auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift}); temp_node = helper_->MakeNode("Concat", {result_0, result_1}); AddAttribute(temp_node, "axis", axes); result_name = temp_node->output(0); } helper_->Reshape(result_name, output_info[0].name, input_info[0].shape); // helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name}); } else { for(int i = 0;i < shifts.size();i++) { auto shift = shifts[i]; int64_t axes = axis[i]; auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()}); auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift}); if(i+1 == shifts.size()) { temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name}); } else { temp_node = helper_->MakeNode("Concat", {result_0, result_1}); } AddAttribute(temp_node, "axis", axes); result_name = temp_node->output(0); } } } } // namespace paddle2onnx
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。