当前位置:   article > 正文

上采样和下采样层 nn.pixelshuffle and nn.pixelunshuffle_pixel unshuffle

pixel unshuffle

前言

理论部分后面有空的时候补一下,这里先放代码和简要说明。


Downsample

这里先对channel维度降低为原来 1 / 2 1/2 1/2,然后再对channel维度提升 r 2 r^2 r2 倍数, 这里 r = 2 r=2 r=2 (nn.PixelUnshuffle中的参数),同时特征图的宽度和高度都会缩减为原来的 1 / r 1/r 1/r

class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

给个例子吧.

ds = Downsample(16)
# up = Upsample(16)
x = torch.randn(1, 16, 64, 64)
x = ds(x)
# x = up(x)
print(x.shape)  # torch.Size([1, 32, 32, 32])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

这里 [ 1 , 16 , 64 , 64 ] [1, 16, 64, 64] [1,16,64,64] 经过 conv 层后变为 [ 1 , 8 , 64 , 64 ] [1, 8, 64, 64] [1,8,64,64], 然后经过 PixelUnshuffle 层 变为 [ 1 , 8 × 4 , 64 / 2 , 64 / 2 ] = [ 1 , 32 , 32 , 32 ] [1, 8 \times 4, 64 / 2, 64 / 2]=[1, 32, 32,32] [1,8×4,64/2,64/2]=[1,32,32,32]

Upsample

这里先对channel维度提升为原来的 2 2 2 倍,然后再对channel维度降低为原来的 1 / ( r 2 ) 1/(r^2) 1/(r2) , 这里 r = 2 r=2 r=2 (nn.PixelUnshuffle中的参数),同时特征图的宽度和高度都会扩大为原来的 2 2 2 倍。

class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

给个例子吧

# ds = Downsample(16)
up = Upsample(16)
x = torch.randn(1, 16, 64, 64)
# x = ds(x)
x = up(x)
print(x.shape)  # torch.Size([1, 8, 128, 128])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

这里 [ 1 , 16 , 64 , 64 ] [1, 16, 64, 64] [1,16,64,64] 经过 conv 层后变为 [ 1 , 32 , 64 , 64 ] [1, 32, 64, 64] [1,32,64,64], 然后经过 PixelUnshuffle 层 变为 [ 1 , 32 / 4 , 64 × 2 , 64 × 2 ] = [ 1 , 8 , 128 , 128 ] [1, 32/4, 64 \times 2, 64 \times 2]=[1, 8, 128, 128] [1,32/4,64×2,64×2]=[1,8,128,128]

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/154230
推荐阅读
  

闽ICP备14008679号