当前位置:   article > 正文

Paddle:reshape算子与pytorch中的view算子_paddle view

paddle view

pytorch中的view算子 = paddle中的reshape算子

  1. # pytorch
  2. # AvgPool2d比AdaptiveAvgPool2d更快,但是使用View 和 Mean会比AvgPool2d快5倍.
  3. class FastGlobalAvgPool2d(nn.Module):
  4. def __init__(self, flatten=False):
  5. super(FastGlobalAvgPool2d, self).__init__()
  6. self.flatten = flatten
  7. def forward(self, x):
  8. if self.flatten:
  9. in_size = x.size()
  10. return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
  11. else:
  12. return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
  1. # paddle
  2. # AvgPool2d比AdaptiveAvgPool2d更快,但是使用View 和 Mean会比AvgPool2d快5倍.
  3. class FastGlobalAvgPool2d(nn.Layer):
  4. def __init__(self, flatten=False):
  5. super(FastGlobalAvgPool2d, self).__init__()
  6. self.flatten = flatten
  7. def forward(self, x):
  8. in_size = x.shape
  9. if self.flatten:
  10. return x.reshape((in_size[0], in_size[1], -1)).mean(axis=-1)
  11. else:
  12. return x.reshape(in_size[0], in_size[1], -1).mean(axis=-1).reshape(in_size[0], in_size[1], 1, 1)

此外:

1:需要注意mean的参数,pytorch中是dim指定维度,paddle中是axis指定维度

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/154159?site
推荐阅读
相关标签
  

闽ICP备14008679号