赞
踩
在PaddlePaddle2.6中,swish算子在PaddleInference上发生了变化,删除掉了beta这个Attr,因此我们需要想办法自行适配它。
原解析relu6算子的核心代码如下:
void SwishMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
std::string beta_node =
helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), beta_);
// TODO(jiangjiajun) eliminate multiply with a constant of value 1
// TODO(jiangjiajun) eliminate add with a constant of value 0
auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
auto sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)});
helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
{output_info[0].name});
}
如果仅需要适配PaddlePaddle2.6,只需要改动为(同时还需要在类的构造函数中删除对beta参数的读取):
void SwishMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");
auto sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name});
helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)},
{output_info[0].name});
}
考虑到要兼容PaddlePaddle2.5之前的用户,因此不能直接删除掉beta这个参数,进一步修改如下:
void SwishMapper::Opset7() { auto input_info = GetInput("X"); auto output_info = GetOutput("Out"); std::shared_ptr<paddle2onnx::NodeProto> sigmod_node = nullptr; if (HasAttr("beta")) { float temp_beta = 1.0; GetAttr("beta", &temp_beta); std::string beta_node = helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), temp_beta); auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node}); sigmod_node = helper_->MakeNode("Sigmoid", {beta_x_node->output(0)}); } else { sigmod_node = helper_->MakeNode("Sigmoid", {input_info[0].name}); } helper_->MakeNode("Mul", {input_info[0].name, sigmod_node->output(0)}, {output_info[0].name}); }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。