当前位置:   article > 正文

pytorch的pixel_shuffle转tflite文件_pixel unshuffle 转tflite

pixel unshuffle 转tflite

torch.pixel_shuffle()是pytorch里面上采样比较常用的方法,但是和tensoflow的depth_to_space不是完全一样的,虽然看起来功能很像,但是细微是有差异的

  1. def tf_pixelshuffle(input, upscale_factor):
  2. temp = []
  3. depth = upscale_factor *upscale_factor
  4. channels = input.shape.as_list()[-1] // depth
  5. for i in range(channels):
  6. out_ = tf.nn.depth_to_space(input=input[:,:, :,i*depth:(i+1)*depth], block_size=upscale_factor)
  7. temp.append(out_)
  8. out = tf.concat(temp, axis=-1)
  9. return out

因为,有人发现在单通道的时候是depth_to_space和pixel_shuffle结果是一样的,所以拆分出来计算好在合并就行,这样速度基本上没有增加多少,亲测速度也是很快的,比从头开始实现pixel_shuffle是快非常多的。

如果使用这样的从头开始实现,转出来的tflite是没法运行在手机上面的,因为tf.transpose的维度太多了,tflite在手机上不支持6个维度的transpose的,因为超过5个维度就会产生flex层,flex层是不被支持的。

  1. def pixel_shuffle(x, upscale_factor):
  2. batch_size, height, width, channels = x.shape
  3. channel_split = channels // (upscale_factor ** 2)
  4. # Reshape the input tensor to split channels
  5. x = tf.reshape(x, (batch_size, height, width, upscale_factor, upscale_factor, channel_split))
  6. # Transpose and reshape to get the pixel shuffled output
  7. x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
  8. x = tf.reshape(x, (batch_size, height * upscale_factor, width * upscale_factor, channel_split))
  9. return x

下面就测试一下:

新建pytorch模型

  1. import torch
  2. import torch.nn as nn
  3. class Net(nn.Module):
  4. def __init__(self):
  5. super().__init__()
  6. self.conv=nn.Conv2d(in_channels=3,
  7. out_channels=12,
  8. kernel_size=3,
  9. stride=2,
  10. padding=1)
  11. def forward(self, input):
  12. x=self.conv(input)
  13. out=torch.pixel_shuffle(x,2)
  14. return out

可视化出来

利用tf_pixelshuffle转出来的结果:

利用pixel_shuffle转出来的结果:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/154240?site
推荐阅读
相关标签
  

闽ICP备14008679号