
Introduction to the Unsupervised Low-Light Image Enhancement Networks Zero-DCE and SCI

Low-light image enhancement networks

Contents

Introduction

Zero-DCE

Algorithm Overview

Model Code

Unsupervised Losses

Summary

Self-Calibrated Illumination (SCI)

Model Overview

Unsupervised Losses

Summary

Conclusion

Introduction

        There are many deep-learning methods for image enhancement, but most of them are supervised and require paired training data, which is hard to obtain in practice. This article covers two unsupervised, deep-learning-based image enhancement methods.

Zero-DCE

Paper: Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement

Paper link: https://openaccess.thecvf.com/content_CVPR_2020/papers/Guo_Zero-Reference_Deep_Curve_Estimation_for_Low-Light_Image_Enhancement_CVPR_2020_paper.pdf

Code: GitHub - Li-Chongyi/Zero-DCE (Zero-DCE code and model)

Algorithm Overview


        Zero-DCE enhances the image with a pixel-wise quadratic curve of the form LE(x) = I(x) + A(x)(I(x)^2 - I(x)) (see the model code below). The network learns the curve parameter maps A_n(x), each with the same spatial resolution as the image and three RGB channels, with n = 8; the final RGB image is obtained by applying the curve iteratively. At higher resolutions the computation becomes heavy and slow: even for a 256x256 image the FLOPs already exceed 5G.

        The network itself has few parameters: just seven 3x3 convolution layers with 32 intermediate channels, and a final layer that outputs the 24-channel A_n maps (8 iterations x 3 channels).
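        To get a feel for both the curve and the cost, here is a small sketch that applies one curve iteration to a single pixel and makes a back-of-envelope estimate of the per-pixel multiply-accumulates of the seven convolutions; the numbers are rough estimates (bias terms and activations ignored), not measured FLOPs.

# Back-of-envelope sketch; ignores bias terms and activations.
# One Zero-DCE curve step on a single pixel value x with curve parameter r:
x, r = 0.2, -0.8                      # a dark pixel and a hypothetical curve parameter
x_enhanced = x + r * (x ** 2 - x)     # LE(x) = x + r*(x^2 - x)
print(x_enhanced)                     # 0.328 -> the pixel is brightened

# Rough per-pixel MACs of the 7 conv layers (3x3 kernels, 32 intermediate channels):
layers = [(3, 32), (32, 32), (32, 32), (32, 32), (64, 32), (64, 32), (64, 24)]
macs_per_pixel = sum(cin * cout * 9 for cin, cout in layers)
print(macs_per_pixel * 256 * 256 / 1e9)  # ~5.2 G multiply-accumulates for a 256x256 image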


Model Code

        The model code is shown below.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class enhance_net_nopool(nn.Module):

    def __init__(self):
        super(enhance_net_nopool, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        number_f = 32
        self.e_conv1 = nn.Conv2d(3, number_f, 3, 1, 1, bias=True)
        self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv5 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv6 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv7 = nn.Conv2d(number_f * 2, 24, 3, 1, 1, bias=True)  # 24 = 8 iterations x 3 channels
        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        # Skip connections via channel concatenation
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        x_r = F.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
        # Split the 24-channel output into 8 curve parameter maps of 3 channels each
        r1, r2, r3, r4, r5, r6, r7, r8 = torch.split(x_r, 3, dim=1)
        # Apply the quadratic curve LE(x) = x + r*(x^2 - x) eight times
        x = x + r1 * (torch.pow(x, 2) - x)
        x = x + r2 * (torch.pow(x, 2) - x)
        x = x + r3 * (torch.pow(x, 2) - x)
        enhance_image_1 = x + r4 * (torch.pow(x, 2) - x)
        x = enhance_image_1 + r5 * (torch.pow(enhance_image_1, 2) - enhance_image_1)
        x = x + r6 * (torch.pow(x, 2) - x)
        x = x + r7 * (torch.pow(x, 2) - x)
        enhance_image = x + r8 * (torch.pow(x, 2) - x)
        r = torch.cat([r1, r2, r3, r4, r5, r6, r7, r8], 1)
        return enhance_image_1, enhance_image, r
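        A minimal sanity check of the forward pass (untrained weights, random input, shapes only):

net = enhance_net_nopool()
low = torch.rand(1, 3, 256, 256)      # a fake low-light image in [0, 1)
mid, enhanced, curves = net(low)
print(enhanced.shape, curves.shape)   # torch.Size([1, 3, 256, 256]) torch.Size([1, 24, 256, 256])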

Unsupervised Losses

        The paper uses four losses. The Spatial Consistency Loss keeps the spatial structure of the enhanced image consistent with the input: both images are averaged over 4x4 neighborhoods, neighborhood gradients of the pooled means are computed in the four directions (up, down, left, right), and the loss encourages these spatial gradients to match between the original and the enhanced image.


        The Spatial Consistency Loss code is shown below.

class L_spa(nn.Module):

    def __init__(self):
        super(L_spa, self).__init__()
        # Fixed difference kernels for the four directions
        kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).cuda().unsqueeze(0).unsqueeze(0)
        self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
        self.pool = nn.AvgPool2d(4)

    def forward(self, org, enhance):
        b, c, h, w = org.shape
        # Average over channels, then over 4x4 patches
        org_mean = torch.mean(org, 1, keepdim=True)
        enhance_mean = torch.mean(enhance, 1, keepdim=True)
        org_pool = self.pool(org_mean)
        enhance_pool = self.pool(enhance_mean)
        # The two quantities below are computed in the original code but not used in the returned loss
        weight_diff = torch.max(torch.FloatTensor([1]).cuda() + 10000 * torch.min(org_pool - torch.FloatTensor([0.3]).cuda(), torch.FloatTensor([0]).cuda()), torch.FloatTensor([0.5]).cuda())
        E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()), enhance_pool - org_pool)
        # Directional gradients of the pooled maps
        D_org_left = F.conv2d(org_pool, self.weight_left, padding=1)
        D_org_right = F.conv2d(org_pool, self.weight_right, padding=1)
        D_org_up = F.conv2d(org_pool, self.weight_up, padding=1)
        D_org_down = F.conv2d(org_pool, self.weight_down, padding=1)
        D_enhance_left = F.conv2d(enhance_pool, self.weight_left, padding=1)
        D_enhance_right = F.conv2d(enhance_pool, self.weight_right, padding=1)
        D_enhance_up = F.conv2d(enhance_pool, self.weight_up, padding=1)
        D_enhance_down = F.conv2d(enhance_pool, self.weight_down, padding=1)
        # Squared difference of gradients in each direction
        D_left = torch.pow(D_org_left - D_enhance_left, 2)
        D_right = torch.pow(D_org_right - D_enhance_right, 2)
        D_up = torch.pow(D_org_up - D_enhance_up, 2)
        D_down = torch.pow(D_org_down - D_enhance_down, 2)
        E = D_left + D_right + D_up + D_down
        return E

        The Exposure Control Loss keeps the image brightness within a reasonable range. The target exposure level E lies in 0.4-0.7; the authors verified that results differ little within this range, and 0.6 is the default. The relevant code is below.


class L_exp(nn.Module):

    def __init__(self, patch_size, mean_val):
        super(L_exp, self).__init__()
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val  # target exposure level E

    def forward(self, x):
        b, c, h, w = x.shape
        x = torch.mean(x, 1, keepdim=True)
        mean = self.pool(x)
        # Penalize the distance of each patch's mean intensity from the target E
        d = torch.mean(torch.pow(mean - torch.FloatTensor([self.mean_val]).cuda(), 2))
        return d

        The Color Constancy Loss keeps colors accurate. It relies on the gray-world assumption, which has its flaws, but the loss still matters a lot for color: the means of the R and G channels should be as close as possible, likewise R and B, and G and B. The code below shows the repository's Sa_Loss class; a sketch of the pairwise channel-mean form just described follows it.


class Sa_Loss(nn.Module):

    def __init__(self):
        super(Sa_Loss, self).__init__()

    def forward(self, x):
        b, c, h, w = x.shape
        r, g, b = torch.split(x, 1, dim=1)
        mean_rgb = torch.mean(x, [2, 3], keepdim=True)
        mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
        Dr = r - mr
        Dg = g - mg
        Db = b - mb
        k = torch.pow(torch.pow(Dr, 2) + torch.pow(Db, 2) + torch.pow(Dg, 2), 0.5)
        k = torch.mean(k)
        return k
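        For reference, here is a minimal sketch of the pairwise channel-mean constraint described above; this is an illustration of the gray-world idea, not necessarily the exact color-constancy implementation used in the repository.

class PairwiseColorLoss(nn.Module):
    # Penalizes differences between the per-image means of the R/G, R/B and G/B channels.
    def forward(self, x):
        mean_rgb = torch.mean(x, [2, 3], keepdim=True)   # per-channel means
        mr, mg, mb = torch.split(mean_rgb, 1, dim=1)
        d_rg = torch.pow(mr - mg, 2)
        d_rb = torch.pow(mr - mb, 2)
        d_gb = torch.pow(mg - mb, 2)
        return torch.mean(d_rg + d_rb + d_gb)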

        The Illumination Smoothness Loss keeps the learned A_n maps spatially smooth, i.e. the gradients between neighboring pixels should be small. The relevant code is below.


class L_TV(nn.Module):

    def __init__(self, TVLoss_weight=1):
        super(L_TV, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = (x.size()[2] - 1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
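        Putting the four losses together, a training step might look roughly like the sketch below. The relative weights and the L_exp parameters (patch size 16, target 0.6) are illustrative assumptions rather than values taken from the paper, and a CUDA device is required because the loss classes above create their kernels with .cuda().

# Hypothetical training-step sketch; loss weights are assumptions for illustration.
net = enhance_net_nopool().cuda()
color_loss, spa_loss, exp_loss, tv_loss = Sa_Loss(), L_spa(), L_exp(16, 0.6), L_TV()

img_lowlight = torch.rand(1, 3, 256, 256).cuda()
enhanced_1, enhanced, A = net(img_lowlight)

loss = (200 * tv_loss(A)                                  # keep the curve maps A smooth
        + torch.mean(spa_loss(img_lowlight, enhanced))    # preserve spatial structure
        + 5 * color_loss(enhanced)                        # color constraint
        + 10 * exp_loss(enhanced))                        # pull exposure toward E = 0.6
loss.backward()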

Summary

        This paper targets low-light images. I have not actually run it, so I will not judge the output quality. Judging from the losses, they all look reasonable; the Exposure Control Loss probably does most of the work by pulling the image into a moderate brightness range. The main weakness is the heavy computation, which makes it essentially unusable in real-time scenarios.

 Self-Calibrated Illumination (SCI)

Paper: Toward Fast, Flexible, and Robust Low-Light Image Enhancement

Paper link: https://arxiv.org/abs/2204.10137

Code: https://github.com/vis-opt-group/SCI

Model Overview


        The paper is based on Retinex theory: as long as the illumination x can be estimated, the clean underlying image z can be recovered. The network H exists to estimate the illumination x; it is trained through a cascade of modules with shared weights, where each stage can be thought of as recovering part of the illumination, but the ultimate goal is that a single pass of H computes x accurately. At inference time only H is used (K is discarded), and only one pass is run. The enhancement network is very light, only three 3x3 convolution layers; the code allows extending it to a deeper network, though a deeper network also means more computation. Unfortunately the paper does not include an ablation study verifying why three convolution layers are enough.


class EnhanceNetwork(nn.Module):

    def __init__(self, layers, channels):
        super(EnhanceNetwork, self).__init__()
        kernel_size = 3
        dilation = 1
        padding = int((kernel_size - 1) / 2) * dilation
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.blocks = nn.ModuleList()
        for i in range(layers):
            # The same conv block is appended each time, so the residual blocks share weights
            self.blocks.append(self.conv)
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        fea = self.in_conv(input)
        for conv in self.blocks:
            fea = fea + conv(fea)
        fea = self.out_conv(fea)
        # The illumination is predicted as a residual on top of the input, then clamped
        illu = fea + input
        illu = torch.clamp(illu, 0.0001, 1)
        return illu
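        At inference time only this illumination estimator is needed: following the Retinex relation described above, the clean image is recovered by dividing the input by the estimated illumination. A minimal sketch (the layers/channels values here are illustrative, not necessarily the paper's settings):

# Hypothetical single-pass inference sketch based on the Retinex relation z = y / x.
net = EnhanceNetwork(layers=1, channels=3)    # illustrative hyper-parameters
y = torch.rand(1, 3, 256, 256)                # low-light input in [0, 1)
with torch.no_grad():
    x = net(y)                                # estimated illumination, clamped to [1e-4, 1]
    z = torch.clamp(y / x, 0, 1)              # enhanced (illumination-removed) image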

        The self-calibration module is only an auxiliary component that helps convergence; without it, H would be hard to learn and could easily collapse to predicting zero. The two networks F and G are cascaded with shared weights, and in fact not many cascade stages are needed, two or three are enough. The self-calibration network is also small; its code is below, and a sketch of how the two modules are wired together follows it.

class CalibrateNetwork(nn.Module):

    def __init__(self, layers, channels):
        super(CalibrateNetwork, self).__init__()
        kernel_size = 3
        dilation = 1
        padding = int((kernel_size - 1) / 2) * dilation
        self.layers = layers
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.convs = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.blocks = nn.ModuleList()
        for i in range(layers):
            self.blocks.append(self.convs)
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input):
        fea = self.in_conv(input)
        for conv in self.blocks:
            fea = fea + conv(fea)
        fea = self.out_conv(fea)
        delta = input - fea
        return delta
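        To make the weight-sharing cascade concrete, below is a rough sketch of how the two modules could be wired at training time, written from the description above; the stage count and bookkeeping details are my assumptions and may differ from the official repository.

# Hypothetical training-time cascade sketch; inference would call self.enhance only once.
class SCICascade(nn.Module):
    def __init__(self, stage=3):
        super(SCICascade, self).__init__()
        self.stage = stage
        self.enhance = EnhanceNetwork(layers=1, channels=3)       # shared across all stages
        self.calibrate = CalibrateNetwork(layers=3, channels=16)  # shared across all stages

    def forward(self, y):
        illu_list, input_list = [], []
        u = y                                    # stage input u^0 = y
        for _ in range(self.stage):
            input_list.append(u)
            x = self.enhance(u)                  # estimated illumination x^t
            z = torch.clamp(y / x, 0, 1)         # Retinex: clean image estimate
            s = self.calibrate(z)                # self-calibration residual
            u = y + s                            # next stage input u^{t+1}
            illu_list.append(x)
        return illu_list, input_list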

Unsupervised Losses

        Only two losses are used: the main one is the fidelity loss; the other, the smoothing loss, simply enforces spatial smoothness.


        The idea is that the first pass F(y) already estimates the illumination, so the residual illumination estimated in the second pass should be zero: at that point the image is considered free of illumination effects, i.e. already a clean image, and no more illumination can be extracted from it. This is where the fidelity loss comes from. In short, illumination is estimated repeatedly with shared weights, but the loss is constructed so that only the first pass extracts illumination; from t >= 1 onward no further illumination should be extracted.
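        A rough sketch of how such a training loss could be written down: the fidelity term compares each stage's illumination estimate with that stage's input, and the smoothing term is applied to the illumination maps. The weighting and exact form are my assumptions for illustration.

# Hypothetical SCI-style loss sketch; the 1.5 smoothness weight is an assumption.
def sci_loss(illu_list, input_list, y, smooth_loss):
    loss = 0.0
    for x_t, u_t in zip(illu_list, input_list):
        # Fidelity: each stage's illumination estimate should match its own stage input,
        # so for t >= 1 (already-clean inputs) there is nothing left to extract.
        loss = loss + F.mse_loss(x_t, u_t)
        # Edge-aware smoothness of the illumination, guided by the original image y.
        loss = loss + 1.5 * smooth_loss(y, x_t)
    return loss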

        This is somewhat analogous to histogram equalization: if the histogram were already perfectly flat, equalizing it again would change nothing. In practice a histogram can never be perfectly flat, so cascading histogram equalization does keep changing the result, but it only oscillates. The same holds here: the loss cannot reach exactly zero, only approach it, which means the illumination should be estimated in a single pass, and for an illumination-free image (almost) no illumination should be extracted.

        The learned result looks a bit like global tone mapping, mainly because the network is only three 3x3 convolution layers and has almost no locality. Making the network deeper would blow up the computation; alternatively, a network like HDRNet could be used to learn x, which would add locality. It is also unclear how the method behaves on daytime scenes; perhaps results are weaker there, which may be why the authors emphasize night images.

        The smooth loss code is below; it is written in a rather verbose way and I have not studied it in detail.

class SmoothLoss(nn.Module):

    def __init__(self):
        super(SmoothLoss, self).__init__()
        self.sigma = 10

    def rgb2yCbCr(self, input_im):
        im_flat = input_im.contiguous().view(-1, 3).float()
        mat = torch.Tensor([[0.257, -0.148, 0.439], [0.564, -0.291, -0.368], [0.098, 0.439, -0.071]]).cuda()
        bias = torch.Tensor([16.0 / 255.0, 128.0 / 255.0, 128.0 / 255.0]).cuda()
        temp = im_flat.mm(mat) + bias
        out = temp.view(input_im.shape[0], 3, input_im.shape[2], input_im.shape[3])
        return out

    # output: output  input: input
    def forward(self, input, output):
        self.output = output
        self.input = self.rgb2yCbCr(input)
        sigma_color = -1.0 / (2 * self.sigma * self.sigma)
        w1 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :] - self.input[:, :, :-1, :], 2), dim=1, keepdim=True) * sigma_color)
        w2 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :] - self.input[:, :, 1:, :], 2), dim=1, keepdim=True) * sigma_color)
        w3 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 1:] - self.input[:, :, :, :-1], 2), dim=1, keepdim=True) * sigma_color)
        w4 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-1] - self.input[:, :, :, 1:], 2), dim=1, keepdim=True) * sigma_color)
        w5 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-1] - self.input[:, :, 1:, 1:], 2), dim=1, keepdim=True) * sigma_color)
        w6 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 1:] - self.input[:, :, :-1, :-1], 2), dim=1, keepdim=True) * sigma_color)
        w7 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-1] - self.input[:, :, :-1, 1:], 2), dim=1, keepdim=True) * sigma_color)
        w8 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 1:] - self.input[:, :, 1:, :-1], 2), dim=1, keepdim=True) * sigma_color)
        w9 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :] - self.input[:, :, :-2, :], 2), dim=1, keepdim=True) * sigma_color)
        w10 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :] - self.input[:, :, 2:, :], 2), dim=1, keepdim=True) * sigma_color)
        w11 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, 2:] - self.input[:, :, :, :-2], 2), dim=1, keepdim=True) * sigma_color)
        w12 = torch.exp(torch.sum(torch.pow(self.input[:, :, :, :-2] - self.input[:, :, :, 2:], 2), dim=1, keepdim=True) * sigma_color)
        w13 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-1] - self.input[:, :, 2:, 1:], 2), dim=1, keepdim=True) * sigma_color)
        w14 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 1:] - self.input[:, :, :-2, :-1], 2), dim=1, keepdim=True) * sigma_color)
        w15 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-1] - self.input[:, :, :-2, 1:], 2), dim=1, keepdim=True) * sigma_color)
        w16 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 1:] - self.input[:, :, 2:, :-1], 2), dim=1, keepdim=True) * sigma_color)
        w17 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, :-2] - self.input[:, :, 1:, 2:], 2), dim=1, keepdim=True) * sigma_color)
        w18 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, 2:] - self.input[:, :, :-1, :-2], 2), dim=1, keepdim=True) * sigma_color)
        w19 = torch.exp(torch.sum(torch.pow(self.input[:, :, 1:, :-2] - self.input[:, :, :-1, 2:], 2), dim=1, keepdim=True) * sigma_color)
        w20 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-1, 2:] - self.input[:, :, 1:, :-2], 2), dim=1, keepdim=True) * sigma_color)
        w21 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, :-2] - self.input[:, :, 2:, 2:], 2), dim=1, keepdim=True) * sigma_color)
        w22 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, 2:] - self.input[:, :, :-2, :-2], 2), dim=1, keepdim=True) * sigma_color)
        w23 = torch.exp(torch.sum(torch.pow(self.input[:, :, 2:, :-2] - self.input[:, :, :-2, 2:], 2), dim=1, keepdim=True) * sigma_color)
        w24 = torch.exp(torch.sum(torch.pow(self.input[:, :, :-2, 2:] - self.input[:, :, 2:, :-2], 2), dim=1, keepdim=True) * sigma_color)
        p = 1.0
        pixel_grad1 = w1 * torch.norm((self.output[:, :, 1:, :] - self.output[:, :, :-1, :]), p, dim=1, keepdim=True)
        pixel_grad2 = w2 * torch.norm((self.output[:, :, :-1, :] - self.output[:, :, 1:, :]), p, dim=1, keepdim=True)
        pixel_grad3 = w3 * torch.norm((self.output[:, :, :, 1:] - self.output[:, :, :, :-1]), p, dim=1, keepdim=True)
        pixel_grad4 = w4 * torch.norm((self.output[:, :, :, :-1] - self.output[:, :, :, 1:]), p, dim=1, keepdim=True)
        pixel_grad5 = w5 * torch.norm((self.output[:, :, :-1, :-1] - self.output[:, :, 1:, 1:]), p, dim=1, keepdim=True)
        pixel_grad6 = w6 * torch.norm((self.output[:, :, 1:, 1:] - self.output[:, :, :-1, :-1]), p, dim=1, keepdim=True)
        pixel_grad7 = w7 * torch.norm((self.output[:, :, 1:, :-1] - self.output[:, :, :-1, 1:]), p, dim=1, keepdim=True)
        pixel_grad8 = w8 * torch.norm((self.output[:, :, :-1, 1:] - self.output[:, :, 1:, :-1]), p, dim=1, keepdim=True)
        pixel_grad9 = w9 * torch.norm((self.output[:, :, 2:, :] - self.output[:, :, :-2, :]), p, dim=1, keepdim=True)
        pixel_grad10 = w10 * torch.norm((self.output[:, :, :-2, :] - self.output[:, :, 2:, :]), p, dim=1, keepdim=True)
        pixel_grad11 = w11 * torch.norm((self.output[:, :, :, 2:] - self.output[:, :, :, :-2]), p, dim=1, keepdim=True)
        pixel_grad12 = w12 * torch.norm((self.output[:, :, :, :-2] - self.output[:, :, :, 2:]), p, dim=1, keepdim=True)
        pixel_grad13 = w13 * torch.norm((self.output[:, :, :-2, :-1] - self.output[:, :, 2:, 1:]), p, dim=1, keepdim=True)
        pixel_grad14 = w14 * torch.norm((self.output[:, :, 2:, 1:] - self.output[:, :, :-2, :-1]), p, dim=1, keepdim=True)
        pixel_grad15 = w15 * torch.norm((self.output[:, :, 2:, :-1] - self.output[:, :, :-2, 1:]), p, dim=1, keepdim=True)
        pixel_grad16 = w16 * torch.norm((self.output[:, :, :-2, 1:] - self.output[:, :, 2:, :-1]), p, dim=1, keepdim=True)
        pixel_grad17 = w17 * torch.norm((self.output[:, :, :-1, :-2] - self.output[:, :, 1:, 2:]), p, dim=1, keepdim=True)
        pixel_grad18 = w18 * torch.norm((self.output[:, :, 1:, 2:] - self.output[:, :, :-1, :-2]), p, dim=1, keepdim=True)
        pixel_grad19 = w19 * torch.norm((self.output[:, :, 1:, :-2] - self.output[:, :, :-1, 2:]), p, dim=1, keepdim=True)
        pixel_grad20 = w20 * torch.norm((self.output[:, :, :-1, 2:] - self.output[:, :, 1:, :-2]), p, dim=1, keepdim=True)
        pixel_grad21 = w21 * torch.norm((self.output[:, :, :-2, :-2] - self.output[:, :, 2:, 2:]), p, dim=1, keepdim=True)
        pixel_grad22 = w22 * torch.norm((self.output[:, :, 2:, 2:] - self.output[:, :, :-2, :-2]), p, dim=1, keepdim=True)
        pixel_grad23 = w23 * torch.norm((self.output[:, :, 2:, :-2] - self.output[:, :, :-2, 2:]), p, dim=1, keepdim=True)
        pixel_grad24 = w24 * torch.norm((self.output[:, :, :-2, 2:] - self.output[:, :, 2:, :-2]), p, dim=1, keepdim=True)
        ReguTerm1 = (torch.mean(pixel_grad1) + torch.mean(pixel_grad2) + torch.mean(pixel_grad3) + torch.mean(pixel_grad4)
                     + torch.mean(pixel_grad5) + torch.mean(pixel_grad6) + torch.mean(pixel_grad7) + torch.mean(pixel_grad8)
                     + torch.mean(pixel_grad9) + torch.mean(pixel_grad10) + torch.mean(pixel_grad11) + torch.mean(pixel_grad12)
                     + torch.mean(pixel_grad13) + torch.mean(pixel_grad14) + torch.mean(pixel_grad15) + torch.mean(pixel_grad16)
                     + torch.mean(pixel_grad17) + torch.mean(pixel_grad18) + torch.mean(pixel_grad19) + torch.mean(pixel_grad20)
                     + torch.mean(pixel_grad21) + torch.mean(pixel_grad22) + torch.mean(pixel_grad23) + torch.mean(pixel_grad24))
        total_term = ReguTerm1
        return total_term

Summary

        The most impressive part of this paper is that it is unsupervised and the results are still quite good. It estimates illumination with a cascaded, weight-sharing scheme and constructs losses that require no labels. The basic assumption behind the unsupervised loss is that once the illumination influence has been removed, an illumination estimator should no longer find any illumination, so the result should be stable. Retinex theory also tells us that if a network can estimate x, the clean image z follows directly, so labels are indeed unnecessary. The first stage estimates the illumination x, which yields z; when z is fed to the illumination estimator again (the second stage), we know z should contain no illumination, or equivalently that its illumination x is 1, so at that point we effectively do have a label.

Conclusion

        Unsupervised learning is the goal everyone is chasing, because it avoids spending large amounts of money creating image labels; in practice most labels today come from hiring retouchers to edit images into a certain artistic style, which is very costly. That said, I suspect these unsupervised methods cannot yet be used for phone photography, since the quality cannot be guaranteed; once a color-cast scene shows up, the result will most likely be unacceptable. I do think, however, that they can be used in machine-oriented imaging, as a preprocessing step serving downstream recognition, for example face recognition or portrait segmentation.

 

 

 

 

 

 

 
