当前位置:   article > 正文

自监督去噪:self2self 原理及实现(Pytorch)

self2self

Self2Self With Dropout: Learning Self-Supervised Denoising From Single Image

在这里插入图片描述

1. 原理简介

噪声图片 y 可以表示为 干净图片 x 和噪声 n的叠加
y = x + n y = x + n y=x+n

使用单个输入进行预测 的原理是:
F θ ( . )    :    y → x F_{\theta}(.) \; : \; y \rightarrow x Fθ(.):yx

常规监督神经网络训练
m i n θ ∑ i L ( F θ ( x ( i ) ) , y ( i ) ) \underset{\theta}{min} \sum_i L(F_{\theta}(x^{(i)}),y^{(i)}) θminiL(Fθ(x(i)),y(i))

其中 F θ F_{\theta} Fθ是神经网络, θ \theta θ是网络参数;但是就从一个神经网络训练的过程来看
M S E = b i a s 2 + v a r i a n c e MSE = bias ^2 + variance MSE=bias2+variance

当训练数据减少的时候,variance会极剧增加。blind-spot技术可以用来阻止这种过拟合现象,但单个样本训练带来的大的variance是无法解决的。这也是基于blind-spot的神经网络 N2V和N2S在单个图片上效果不好的原因。

Dropout技术是一种广泛应用的正则化技术,同时其可以提供一定程度的不确定性估计,避免出现恒等映射。盲点策略通过对噪声数据随机采样合成多个不同的噪声数据版本,并在这些替换样本上计算损失。因此本文提出的一个策略就变为了:在输入图像的伯努利采样实例上定义自预测损失函数
y ^ [ k ] = { y [ k ] , w i t h    p r o b a b i l i t y    p ; 0 , w i t h    p r o b a b i l i t y    1 − p \hat{y}[k] = {y[k],withprobabilityp;0,withprobability1p y^[k]={y[k]0,withprobabilityp;,withprobability1p

采样两个 Bernoulli 采样实例数据集 y ^ m {\hat{y}_m} y^m y n ^ \hat{y_n} yn^

  • 训练过程,最小化下面这个损失
    m i n θ ∑ m L ( F θ ( y ^ m ) , y − y ^ m ) \underset{\theta}{min} \sum_m L(F_{\theta}(\hat{y}_m),y-\hat{y}_m) θminmL(Fθ(y^m),yy^m)

  • 测试过程:在另一个采样数据集上, 得到每一个 y n y_n yn对应的预测结果,然后求一个平均值得到最后的去噪数据


2. 网络结构

在这里插入图片描述

  • Encoder结构

    • 输入大小 H × W × C H \times W \times C H×W×C
    • 使用 partial convolution layer(Pconv)将输入变为 H × W × 48 H \times W \times 48 H×W×48
    • 然后使用六个 encoder block(EBs):
      • 前五个包含 Pconv层,1个 Leakey ReLu激活函数,一个最大池化层(2*2感受野、stride为2)
      • 最后一层只有 Pconv层和 一个 Leakey ReLU激活函数
      • 通道固定为48
    • 编码器的输出为 H / 32 × W / 32 × 48 H/32 \times W/32 \times 48 H/32×W/32×48
  • Decoder 结构:

    • 包含五个decoder blocks
      • 前四个blcok每一个包含一个上采样参数为2的上采样层,一个concate操作,两个标准的Conv层和 Leakey Relu激活。concate操作是将上采样得到的结果进行了聚集。
      • 前四个block都有96个输出通道
    • 最后一个decoder block有三个dropout层,使用LeakeyReLU激活函数。最后将输出恢复为 H × W × C H \times W \times C H×W×C的大小

部分细节:

  • 所有的PConv层和Conv层都使用kernel size 3*3,strid = 1,padding = 2
  • Leakdy ReLU的斜率为 0.1
  • droupouts的概率为0.3
  • bernoulli sampling的概率为 0.3
  • 使用Adam优化器,学习率 1 0 − 5 10^{-5} 105,迭代450000次

结构和 Noise2Noise结构基本相似,不同点在于:

  • 在Decoder中加入了dropout (不确定性估计和稳定性)
  • 在Encoder中使用部分卷积替代标准卷积

3. Pytorch实现

(1)Partial convolution 结构

注意,这里是使用的 部分卷积网络,所以使用了 NVIDIA的实现,

import torch
import torch.nn.functional as F
from torch import nn, cuda
from torch.autograd import Variable

class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):

        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False  

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        #####Yize's fixes
        self.multi_channel = True
        self.return_mask = True
        
        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
            
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)

            with torch.no_grad():
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in
                        
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)


        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)


        if self.return_mask:
            return output, self.update_mask
        else:
            return output
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
(2) U-net 网络结构
class EncodeBlock(nn.Module):
    def __init__(self,in_channel,out_channel,flag):
        super(EncodeBlock,self).__init__()
        self.conv = PartialConv2d(in_channel, out_channel, kernel_size = 3, padding = 1)
        self.nonlinear = nn.LeakyReLU(0.1)
        self.MaxPool = nn.MaxPool2d(2)
        self.flag = flag
    
    def forward(self, x, mask_in):
        out1, mask_out = self.conv(x, mask_in = mask_in)
        out2 = self.nonlinear(out1)
        if self.flag:
            out = self.MaxPool(out2)
            mask_out = self.MaxPool(mask_out)
        else:
            out = out2
        return out, mask_out
    
class DecodeBlock(nn.Module):
    def __init__(self, in_channel, mid_channel, out_channel, final_channel = 3, p = 0.7, flag = False):
        super(DecodeBlock,self).__init__()
        self.conv1 = nn.Conv2d(in_channel,mid_channel,kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(mid_channel,out_channel,kernel_size=3,padding=1)
        self.conv3 = nn.Conv2d(out_channel,final_channel,kernel_size=3,padding=1)
        self.nonlinear1 = nn.LeakyReLU(0.1)
        self.nonlinear2 = nn.LeakyReLU(0.1)
        self.sigmoid = nn.Sigmoid()
        self.flag = flag
        self.Dropout = nn.Dropout(p)
    
    def forward(self,x):
        out1 = self.conv1(self.Dropout(x))
        out2 = self.nonlinear1(out1)
        out3 = self.conv2(self.Dropout(out2))
        out4 = self.nonlinear2(out3)
        if self.flag:
            out5 = self.conv3(self.Dropout(out4))
            out = self.sigmoid(out5)
        else:
            out = out4
        return out
        
class self2self(nn.Module):
    def __init__(self,in_channel,p):
        super(self2self,self).__init__()
        self.EB0 = EncodeBlock(in_channel,out_channel=48,flag=False)
        self.EB1 = EncodeBlock(48,48,flag=True)
        self.EB2 = EncodeBlock(48,48,flag=True)
        self.EB3 = EncodeBlock(48,48,flag=True)
        self.EB4 = EncodeBlock(48,48,flag=True)
        self.EB5 = EncodeBlock(48,48,flag=True)
        self.EB6 = EncodeBlock(48,48,flag=False)
        
        self.DB1 = DecodeBlock(in_channel=96,mid_channel=96,out_channel=96,p=p)
        self.DB2 = DecodeBlock(in_channel=144,mid_channel=96,out_channel=96,p=p)
        self.DB3 = DecodeBlock(in_channel=144,mid_channel=96,out_channel=96,p=p)
        self.DB4 = DecodeBlock(in_channel=144,mid_channel=96,out_channel=96,p=p)
        self.DB5 = DecodeBlock(in_channel=96+in_channel,mid_channel=64,out_channel=32,p=p,flag=True)
        
        self.Upsample = nn.Upsample(scale_factor=2,mode='bilinear')
        self.concat_dim = 1
    
    def forward(self,x,mask):
        out_EB0,mask = self.EB0(x,mask)                 # [3,w,h]        ->     [48,w,h]
        out_EB1,mask = self.EB1(out_EB0,mask_in=mask)   # [48,w,h]       ->     [48,w/2,h/2]
        out_EB2,mask = self.EB2(out_EB1,mask_in=mask)   # [48,w/2,h/2]   ->     [48,w/4,h/4]
        out_EB3,mask = self.EB3(out_EB2,mask_in=mask)   # [48,w/4,h/4]   ->     [48,w/8,h/8]
        out_EB4,mask = self.EB4(out_EB3,mask_in=mask)   # [48,w/8,h/8]   ->     [48,w/16,h/16]
        out_EB5,mask = self.EB5(out_EB4,mask_in=mask)   # [48,w/16,h/16] ->     [48,w/32,h/32]
        out_EB6,mask = self.EB6(out_EB5,mask_in=mask)   # [48,w/32,h/32] ->     [48,w/32,h/32]
        
        out_EB6_up = self.Upsample(out_EB6)             # [48,w/32,h/32] ->     [48,w/16,h/16]
        in_DB1 = torch.cat((out_EB6_up,out_EB4),self.concat_dim) # [48,w/16,h/16] -> [96,w/16,h/16]
        out_DB1 = self.DB1((in_DB1))                    # [96,w/16,h/16] ->     [96,w/16,h/16]
        
        out_DB1_up = self.Upsample(out_DB1)             # [96,w/16,h/16] ->     [96,w/8,h/8]
        in_DB2 = torch.cat((out_DB1_up,out_EB3),self.concat_dim) # [96,w/8,w/8] -> [144,w/8,w/8]
        out_DB2 = self.DB2((in_DB2))                    # [144,w/8,w/8] -> [96,w/8,w/8]
        
        out_DB2_up = self.Upsample(out_DB2)             # [96,w/8,h/8] ->     [96,w/4,h/4]
        in_DB3 = torch.cat((out_DB2_up,out_EB2),self.concat_dim) # [96,w/4,w/4] -> [144,w/4,w/4]
        out_DB3 = self.DB2((in_DB3))                    # [144,w/4,w/4] -> [96,w/4,w/4]
        
        out_DB3_up = self.Upsample(out_DB3)             # [96,w/4,h/4] ->     [96,w/2,h/2]
        in_DB4 = torch.cat((out_DB3_up, out_EB1),self.concat_dim) # [96,w/2,w/2] -> [144,w/2,w/2]
        out_DB4 = self.DB4((in_DB4))                    # [144,w/2,w/2] -> [96,w/2,w/2]
        
        out_DB4_up = self.Upsample(out_DB4)             # [96,w/2,h/2] ->     [96,w,h]
        in_DB5 = torch.cat((out_DB4_up, x),self.concat_dim) # [96,w,h] ->     [96+c,w,h]
        out_DB5 = self.DB5(in_DB5)                      # [96+c,w,h] ->     [32,w,h]
        return out_DB5
    
model = self2self(3,0.3)
model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
(3)网络训练
import numpy as np 
import matplotlib.pyplot as plt
import torch.optim as optim
import torchvision.transforms as T
import cv2 
from PIL import Image
from tqdm import tqdm

# 图片加载
img = np.array(Image.open("5.png"))

plt.figure()
plt.imshow(img)
plt.show()
img.shape
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

在这里插入图片描述

# 参数设置
##Enable GPU
USE_GPU = True

dtype = torch.float32

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

learning_rate = 1e-4
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr = learning_rate)
w,h,c = img.shape
p=0.3
NPred=100
slice_avg = torch.tensor([1,3,512,512]).to(device)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
# 训练迭代
def image_loader(image, device, p1, p2):
    """
        load image and returns cuda tensor
    """
    loader = T.Compose([
            T.RandomHorizontalFlip(torch.round(torch.tensor(p1))),
            T.RandomVerticalFlip(torch.round(torch.tensor(p2))),
            T.ToTensor()])
    image = Image.fromarray(image.astype(np.uint8))
    image = loader(image).float()
    if not torch.is_tensor(image):
        image = torch.tensor(image)
    image = image.unsqueeze(0)  #this is for VGG, may not be needed for ResNet
    return image.to(device)

pbar = tqdm(range(500000))
for itr in pbar:
    # 不知道这个采样是否正确,是不是需要在每一个通道都分别进行均匀采样?
    p_mtx = np.random.uniform(size=[img.shape[0],img.shape[1],img.shape[2]])
    mask = (p_mtx>p).astype(np.double)
    img_input = img
    
    y = img
    p1 = np.random.uniform(size=1)
    p2 = np.random.uniform(size=1)
    # 加载输入图片(根据概率进行翻转)
    img_input_tensor = image_loader(img_input, device, p1, p2)
    
    # 对原始图片进行相同操作(翻转)
    y = image_loader(y, device, p1, p2)
    
    # mask为伯努利采样结果
    mask = np.expand_dims(np.transpose(mask,[2,0,1]),0)
    mask = torch.tensor(mask).to(device, dtype=torch.float32)

    # 网络推理
    model.train()
    img_input_tensor = img_input_tensor*mask
    output = model(img_input_tensor, mask)

    # 损失函数
    # loss = torch.sum((output+img_input_tensor-y)*(output+img_input_tensor-y)*(1-mask))/torch.sum(1-mask)
    loss = torch.sum((output-y)*(output-y)*(1-mask))/torch.sum(1-mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    pbar.set_description("iteration {}, loss = {:.4f}".format(itr+1, loss.item()*100))

    if (itr+1)%1000 == 0:
        model.eval()
        sum_preds = np.zeros((img.shape[0],img.shape[1],img.shape[2]))
        for j in range(NPred):
            p_mtx = np.random.uniform(size=img.shape)
            mask = (p_mtx>p).astype(np.double)
            img_input = img*mask
            img_input_tensor = image_loader(img_input, device, 0.1, 0.1)
            mask = np.expand_dims(np.transpose(mask,[2,0,1]),0)
            mask = torch.tensor(mask).to(device, dtype=torch.float32)
            
            output_test = model(img_input_tensor,mask)
            sum_preds[:,:,:] += np.transpose(output_test.detach().cpu().numpy(),[2,3,1,0])[:,:,:,0]
        avg_preds = np.squeeze(np.uint8(np.clip((sum_preds-np.min(sum_preds)) / (np.max(sum_preds)-np.min(sum_preds)), 0, 1) * 255))
        write_img = Image.fromarray(avg_preds)
        write_img.save("./examples/images/Self2self-"+str(itr+1)+".png")
        torch.save(model.state_dict(),'./examples/models/model-'+str(itr+1))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66

展示一下这里进行伯努利采样得到的结果和输入的噪声图片的区别
在这里插入图片描述

(4)迭代结果

展示不同次数的结果:
1000,10000,20000,30000次迭代

总结

从我自己可能会用到的地方进行 评价 (不是评价啊哈,大佬的工作真的非常棒,就是从我们迁移应用的角度看待)

  • 单样本任务,不需要合成特别多的样本
  • 使用Dropout引入了模型的不确定性估计,可以使得恢复更加稳定
  • 使用部分卷积替代常规卷积,对于图片去噪和恢复有一定的效果
  • 和Deep Image Prior相比,二者都不需要多余的样本,但是self2self更加稳定

一些小问题:

  • 迭代次数太多,上述操作迭代了500000次
  • 如果一张照片去噪需要1小时,那么其应用场景比较有限
  • 其实损失函数的设计,对该方法有一定的影响,可以尝试一下不同的损失函数,其结果会有一定的影响
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/207049
推荐阅读
相关标签
  

闽ICP备14008679号