赞
踩
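PyTorch 实现 sub-pixel 上采样及下采样，类似于 TensorFlow 的 tf.depth_to_space / tf.space_to_depth。注意：下面的 shuffle_up 与 shuffle_down 互为逆操作，但通道内子像素的排列顺序与 tf.depth_to_space 及 torch.nn.PixelShuffle 均不同；文末论坛示例中的 DepthToSpace / SpaceToDepth 则与 TensorFlow 的结果逐元素一致。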
def shuffle_down(inputs, scale):
    """Fold every ``scale x scale`` spatial patch into the channel dimension.

    Sub-pixel down-sampling (space-to-depth) for a 4-D ``(N, C, H, W)``
    tensor.  It is the inverse of ``shuffle_up`` with the same ``scale``.

    Args:
        inputs: 4-D tensor of shape ``(N, C, H, W)``.
        scale: block size; must divide both ``H`` and ``W``.

    Returns:
        Tensor of shape ``(N, C * scale**2, H // scale, W // scale)`` with
        ``out[n, c*scale**2 + j*scale + i, y, x] ==
        inputs[n, c, y*scale + i, x*scale + j]``
        (``i`` = row offset inside the patch, ``j`` = column offset).

        NOTE: the column offset varies slower than the row offset.  This is
        the transpose of the order used by
        ``torch.nn.functional.pixel_unshuffle`` and differs from
        ``tf.nn.space_to_depth``.  It is kept as-is because changing it would
        silently alter results for existing callers.

    Raises:
        RuntimeError: if ``H`` or ``W`` is not divisible by ``scale``.
    """
    N, C, iH, iW = inputs.size()
    # Check up front: for an empty batch the view below would otherwise
    # "succeed" on a mis-sized shape instead of reporting the problem.
    if iH % scale or iW % scale:
        raise RuntimeError(
            "shuffle_down: spatial size ({}, {}) is not divisible by "
            "scale={}".format(iH, iW, scale)
        )
    oH = iH // scale
    oW = iW // scale
    # (N, C, oH, i, oW, j) -> (N, C, j, i, oH, oW)
    output = inputs.view(N, C, oH, scale, oW, scale)
    output = output.permute(0, 1, 5, 3, 2, 4).contiguous()
    # Spell the channel count out: view(N, -1, ...) cannot infer -1 when the
    # tensor has zero elements (e.g. N == 0) and raises.
    return output.view(N, C * scale * scale, oH, oW)
def shuffle_up(inputs, scale):
    """Unfold groups of ``scale**2`` channels into ``scale x scale`` patches.

    Sub-pixel up-sampling (depth-to-space) for a 4-D ``(N, C, H, W)``
    tensor.  It is the inverse of ``shuffle_down`` with the same ``scale``.

    Args:
        inputs: 4-D tensor of shape ``(N, C, H, W)``.
        scale: up-scaling factor; ``C`` must be divisible by ``scale**2``.

    Returns:
        Tensor of shape ``(N, C // scale**2, H * scale, W * scale)`` with
        ``out[n, c, y*scale + i, x*scale + j] ==
        inputs[n, c*scale**2 + j*scale + i, y, x]``
        (``i`` = row offset inside the patch, ``j`` = column offset).

        NOTE: this sub-pixel order differs from ``torch.nn.PixelShuffle``
        and ``tf.nn.depth_to_space``.  It is kept as-is because changing it
        would silently alter results for existing callers.

    Raises:
        RuntimeError: if ``C`` is not divisible by ``scale**2``.
    """
    N, C, iH, iW = inputs.size()
    # Check up front: with an empty batch the view below would silently
    # accept a wrong channel count instead of failing.
    if C % (scale ** 2):
        raise RuntimeError(
            "shuffle_up: channels ({}) not divisible by scale**2 "
            "({})".format(C, scale ** 2)
        )
    oH = iH * scale
    oW = iW * scale
    oC = C // (scale ** 2)
    # (N, oC, j, i, H, W) -> (N, oC, H, i, W, j)
    output = inputs.view(N, oC, scale, scale, iH, iW)
    output = output.permute(0, 1, 4, 3, 5, 2).contiguous()
    return output.view(N, oC, oH, oW)
==================
PyTorch 论坛上的例子：
import torch
from torch import nn


class DepthToSpace(nn.Module):
    """Move channel blocks into spatial blocks (NCHW ``tf.nn.depth_to_space``).

    With ``C' = C // bs**2``, input channel ``(b1 * bs + b2) * C' + c`` at
    ``(h, w)`` lands in output channel ``c`` at ``(h * bs + b1, w * bs + b2)``.
    This is TensorFlow's ordering, so the result equals
    ``tf.nn.depth_to_space`` applied in NHWC and transposed back to NCHW.

    Args:
        block_size: spatial up-scaling factor ``bs``; the input channel count
            must be divisible by ``bs ** 2``.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        """Map ``(N, C, H, W)`` to ``(N, C // bs**2, H * bs, W * bs)``.

        Raises:
            RuntimeError: if ``C`` is not divisible by ``bs ** 2``.
        """
        N, C, H, W = x.size()
        bs = self.bs
        if C % (bs ** 2):
            raise RuntimeError(
                "DepthToSpace: channels ({}) not divisible by "
                "block_size**2 ({})".format(C, bs ** 2)
            )
        oC = C // (bs ** 2)
        x = x.view(N, bs, bs, oC, H, W)  # (N, bs, bs, C//bs^2, H, W)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        return x.view(N, oC, H * bs, W * bs)  # (N, C//bs^2, H*bs, W*bs)


class SpaceToDepth(nn.Module):
    """Move spatial blocks into channels (NCHW ``tf.nn.space_to_depth``).

    The exact inverse of ``DepthToSpace`` with the same block size: pixel
    ``(h * bs + b1, w * bs + b2)`` of channel ``c`` lands in output channel
    ``(b1 * bs + b2) * C + c`` at ``(h, w)``.

    Args:
        block_size: spatial down-scaling factor ``bs``; ``H`` and ``W`` must
            both be divisible by it.
    """

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        """Map ``(N, C, H, W)`` to ``(N, C * bs**2, H // bs, W // bs)``.

        Raises:
            RuntimeError: if ``H`` or ``W`` is not divisible by ``bs``.
        """
        N, C, H, W = x.size()
        bs = self.bs
        if H % bs or W % bs:
            raise RuntimeError(
                "SpaceToDepth: spatial size ({}, {}) not divisible by "
                "block_size ({})".format(H, W, bs)
            )
        oH, oW = H // bs, W // bs
        x = x.view(N, C, oH, bs, oW, bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        return x.view(N, C * (bs ** 2), oH, oW)  # (N, C*bs^2, H//bs, W//bs)


def main():
    """Round-trip a random batch through both modules and cross-check with TF.

    Runs only as a script, so importing this module needs neither TensorFlow
    nor the demo tensors.
    """
    import tensorflow as tf  # only needed for the cross-framework check

    # PyTorch (NCHW)
    x1 = torch.rand(64, 256, 8, 8)
    x2 = DepthToSpace(2)(x1)
    x3 = SpaceToDepth(2)(x2)
    print(x1.size())
    print(x2.size())
    print(x3.size())
    print(torch.equal(x1, x3))

    # TensorFlow works in NHWC; tf.nn.* names exist in both TF 1.x and 2.x.
    y1 = tf.transpose(x1.numpy(), [0, 2, 3, 1])  # NCHW -> NHWC
    y2 = tf.nn.depth_to_space(y1, 2)
    y3 = tf.nn.space_to_depth(y2, 2)
    y1 = tf.transpose(y1, [0, 3, 1, 2])  # NHWC -> NCHW
    y2 = tf.transpose(y2, [0, 3, 1, 2])
    y3 = tf.transpose(y3, [0, 3, 1, 2])
    if tf.executing_eagerly():  # TF 2.x default: tensors are already computed
        y1, y2, y3 = (t.numpy() for t in (y1, y2, y3))
    else:  # TF 1.x graph mode (tf.Session is gone from the TF 2 namespace)
        with tf.compat.v1.Session() as sess:
            y1, y2, y3 = sess.run([y1, y2, y3])
    print(y1.shape)
    print(y2.shape)
    print(y3.shape)
    print((y1 == y3).all())

    # check consistency between the two frameworks
    print((x1.numpy() == y1).all())
    print((x2.numpy() == y2).all())
    print((x3.numpy() == y3).all())


if __name__ == "__main__":
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。