当前位置:   article > 正文

【PyTorch】实现subpixel上采样及下采样_pytorch 下采样 避免出现锯齿

pytorch 下采样 避免出现锯齿

PyTorch 实现 subpixel 上采样及下采样，类似于 TensorFlow 的 tf.depth_to_space、tf.space_to_depth

def shuffle_down(inputs, scale):
    """Fold every ``scale x scale`` spatial block of ``inputs`` into channels.

    Args:
        inputs: 4-D tensor unpacked as ``(N, C, H, W)``.
        scale: Integer block edge length.

    Returns:
        Tensor viewed as ``(N, -1, H // scale, W // scale)``. For input
        channel ``c`` and in-block offsets ``dy`` (row) and ``dx`` (column),
        the output channel index is ``c * scale**2 + dx * scale + dy``,
        i.e. the column offset varies slower than the row offset.
    """
    batch, channels, in_h, in_w = inputs.shape
    out_h, out_w = in_h // scale, in_w // scale

    # Split H into (out_h, dy) and W into (out_w, dx).
    blocks = inputs.view(batch, channels, out_h, scale, out_w, scale)
    # Reorder to (N, C, dx, dy, out_h, out_w) so both offsets sit next to C.
    folded = blocks.permute(0, 1, 5, 3, 2, 4).contiguous()
    # Merge (C, dx, dy) into a single channel axis.
    return folded.view(batch, -1, out_h, out_w)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
def shuffle_up(inputs, scale):
    """Spread groups of ``scale**2`` channels out into ``scale x scale`` blocks.

    Args:
        inputs: 4-D tensor unpacked as ``(N, C, H, W)``.
        scale: Integer block edge length.

    Returns:
        Tensor viewed as ``(N, C // scale**2, H * scale, W * scale)``. Input
        channel ``c * scale**2 + a * scale + b`` at position ``(h, w)`` is
        written to output channel ``c`` at row ``h * scale + b`` and column
        ``w * scale + a``.
    """
    batch, channels, in_h, in_w = inputs.shape
    out_channels = channels // (scale ** 2)

    # Split the channel axis into (out_channels, a, b).
    grouped = inputs.view(batch, out_channels, scale, scale, in_h, in_w)
    # Reorder to (N, out_channels, H, b, W, a) so offsets follow their axis.
    interleaved = grouped.permute(0, 1, 4, 3, 5, 2).contiguous()
    # Merge (H, b) into rows and (W, a) into columns.
    return interleaved.view(batch, out_channels, in_h * scale, in_w * scale)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

==================
pytorch 论坛的例子

from torch import nn

class DepthToSpace(nn.Module):
    """Module that moves channel data into ``block_size``-sized spatial blocks.

    For an input unpacked as ``(N, C, H, W)`` the output is viewed as
    ``(N, C // bs**2, H * bs, W * bs)`` where ``bs`` is the block size.
    """

    def __init__(self, block_size):
        """Store ``block_size`` as ``self.bs``."""
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        """Rearrange ``x`` from depth to space.

        Input channel ``b1 * bs * C_out + b2 * C_out + c`` at ``(h, w)`` ends
        up in output channel ``c`` at row ``h * bs + b1``, column
        ``w * bs + b2``, with ``C_out = C // bs**2``.
        """
        bs = self.bs
        batch, channels, height, width = x.shape
        out_channels = channels // (bs ** 2)

        # Split channels into (b1, b2, out_channels).
        split = x.view(batch, bs, bs, out_channels, height, width)
        # Reorder to (N, out_channels, H, b1, W, b2).
        mixed = split.permute(0, 3, 4, 1, 5, 2).contiguous()
        # Merge (H, b1) and (W, b2) into the enlarged spatial axes.
        return mixed.view(batch, out_channels, height * bs, width * bs)


class SpaceToDepth(nn.Module):
    """Module that folds ``block_size``-sized spatial blocks into channels.

    For an input unpacked as ``(N, C, H, W)`` the output is viewed as
    ``(N, C * bs**2, H // bs, W // bs)`` where ``bs`` is the block size.
    """

    def __init__(self, block_size):
        """Store ``block_size`` as ``self.bs``."""
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        """Rearrange ``x`` from space to depth.

        Input channel ``c`` at row ``h * bs + b1``, column ``w * bs + b2``
        ends up in output channel ``b1 * bs * C + b2 * C + c`` at ``(h, w)``.
        """
        bs = self.bs
        batch, channels, height, width = x.shape
        out_h, out_w = height // bs, width // bs

        # Split rows into (out_h, b1) and columns into (out_w, b2).
        split = x.view(batch, channels, out_h, bs, out_w, bs)
        # Reorder to (N, b1, b2, C, out_h, out_w).
        mixed = split.permute(0, 3, 5, 1, 2, 4).contiguous()
        # Merge (b1, b2, C) into the enlarged channel axis.
        return mixed.view(batch, channels * (bs ** 2), out_h, out_w)


import tensorflow as tf
import torch

# Demo script: push a random tensor through DepthToSpace and back through
# SpaceToDepth, do the same with the TensorFlow ops, and print whether the
# two frameworks agree at every stage.

# PyTorch side: x1 -> DepthToSpace -> x2 -> SpaceToDepth -> x3.
x1 = torch.rand(64, 256, 8, 8)
x2 = DepthToSpace(2)(x1)
x3 = SpaceToDepth(2)(x2)
print(x1.size())
print(x2.size())
print(x3.size())
# Prints whether the round trip reproduced the input element-for-element.
print((x1 == x3).all())

# TensorFlow side: same data, transposed first because the ops below are fed
# channels-last tensors.
# NOTE(review): tf.depth_to_space / tf.space_to_depth / tf.Session look like
# TensorFlow 1.x-style APIs — confirm against the installed version.
y1 = tf.transpose(x1.numpy(), [0, 2, 3, 1])  # NCHW -> NHWC
y2 = tf.depth_to_space(y1, 2)
y3 = tf.space_to_depth(y2, 2)

# Transpose every stage back so it can be compared with the PyTorch tensors.
y1 = tf.transpose(y1, [0, 3, 1, 2])  # NHWC -> NCHW
y2 = tf.transpose(y2, [0, 3, 1, 2])
y3 = tf.transpose(y3, [0, 3, 1, 2])

# Evaluate the three graph tensors; the .shape/.all() calls below are made on
# the values returned by run().
y1, y2, y3 = tf.Session().run([y1, y2, y3])
print(y1.shape)
print(y2.shape)
print(y3.shape)
# Prints whether the TensorFlow round trip reproduced its input.
print((y1 == y3).all())

# Cross-framework check: each PyTorch stage against the matching TF stage.
print((x1.numpy() == y1).all())
print((x2.numpy() == y2).all())
print((x3.numpy() == y3).all())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/154273
推荐阅读
相关标签
  

闽ICP备14008679号