当前位置:   article > 正文

【代码复现Zero-DCE详解:Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement】_zero-dce代码

zero-dce代码

在这里插入图片描述

链接概括

1.文章:(CVPR 2020) Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement
2. 链接: paper.
3. 链接: code.
4. 其他博主复现链接: link.

存在的几个主要问题

1.检查电脑是否含有GPU,不是GPU环境需要将代码改为CPU环境下
2. 源代码中没有数据集
3. 代码中路径问题,路径不对将导致无法加载数据
3.需要按照要求设置文件夹
4. 由于版本不同,代码运行中出现的警告影响代码运行

检查电脑是否含有GPU:将代码改为CPU环境下运行

论文中的代码是GPU环境下的,并且是在每一句需要用到GPU的代码下注释的,所以需要将所有含有该语令的改为CPU
具体改法:
将代码中".cuda()"删去,或者将其改为“.cpu()”
下面仅仅是部分示例:其中前面带“#”为源代码

// Myloss.py
        # kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        # kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        # kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        # kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).unsqueeze(0).unsqueeze(0)

        # weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())
        weight_diff =torch.max(torch.FloatTensor([1]) + 10000*torch.min(org_pool - torch.FloatTensor([0.3]),torch.FloatTensor([0])),torch.FloatTensor([0.5]))
        # E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)
        E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5])), enhance_pool - org_pool)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

在CPU条件下 在训练模型中需要将pin_memory=True改为pin_memory=False

//  lowlight_train.py
   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
	# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True,num_workers=config.num_workers, pin_memory=False)

  • 1
  • 2
  • 3
  • 4

源代码中没有训练数据集

该论文源代码里面有预训练模型,所以没有训练数据集,并不影响直接测试代码。
该Epoch99.pth即为预训练模型。所以训练集并不影响测试。
在这里插入图片描述
当然你也可以根据代码要求寻找数据集加入,并训练出自己的预训练模型

代码中路径问题,路径不对将导致无法加载数据

在dataloader.py文件中 :
作用是为数据做预处理的相关代码。
文件路径出现了“.jpg”,如果训练集中的数据为PNG格式,就需要将最后改为“.png”

// dataloader.py
def populate_train_list(lowlight_images_path):#获取训练列表(微光图像路径)

	# image_list_lowlight = glob.glob(lowlight_images_path + "*.jpg")
	image_list_lowlight = glob.glob(lowlight_images_path + "*.png")

	train_list = image_list_lowlight

	random.shuffle(train_list)

	return train_list
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

在 lowlight_train.py文件中 :
将文件路径改为自己的数据路径
绝对路径,相对路径都可以
“在开始调试的代码过程中可以更加学习率与减少迭代次数,在代码调通后在改为规定的参数,可以运行更快”

// lowlight_train.py
if __name__ == "__main__":

	parser = argparse.ArgumentParser()

	# Input Parameters
	# parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/")
	parser.add_argument('--lowlight_images_path', type=str, default="E:/image/Zero-DCE-master/Zero-DCE_code/data/train_data/")
	# parser.add_argument('--lr', type=float, default=0.0001)
	# parser.add_argument('--weight_decay', type=float, default=0.0001)
	parser.add_argument('--grad_clip_norm', type=float, default=0.1)
	parser.add_argument('--lr', type=float, default=0.01)
	parser.add_argument('--weight_decay', type=float, default=0.01)
	parser.add_argument('--num_epochs', type=int, default=20)
	# parser.add_argument('--num_epochs', type=int, default=20)
	parser.add_argument('--train_batch_size', type=int, default=8)
	parser.add_argument('--val_batch_size', type=int, default=4)
	parser.add_argument('--num_workers', type=int, default=4)
	parser.add_argument('--display_iter', type=int, default=10)
	parser.add_argument('--snapshot_iter', type=int, default=10)
	parser.add_argument('--snapshots_folder', type=str, default="E:/image/Zero-DCE-master/Zero-DCE_code/snapshots/")
	# parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
	parser.add_argument('--load_pretrain', type=bool, default= False)
	parser.add_argument('--pretrain_dir', type=str, default= "E:/image/Zero-DCE-master/Zero-DCE_code/snapshots/Epoch99.pth")
	# parser.add_argument('--pretrain_dir', type=str, default="snapshots/Epoch99.pth")
	config = parser.parse_args()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

需要按照要求设置文件夹

在data文件夹下设置result文件夹,并在result文件下设置与test_data文件夹下一样的两个文件夹:DICM与LIME文件夹
必须一样,否则将报错
在这里插入图片描述

由于版本不同,代码运行中出现的警告影响代码运行

一、 torch.nn.utils.clip_grad_norm函数被弃用

//  torch.nn.utils.clip_grad_norm函数被弃用 的警告
UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 
torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)
  • 1
  • 2
  • 3

改法;

//  lowlight_train.py
        # torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)
   		torch.nn.utils.clip_grad_norm_(DCE_net.parameters(),config.grad_clip_norm)
  • 1
  • 2
  • 3

二、 nn.functional.tanh被弃用

//  nn.functional.tanh被弃用的警告
UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
 warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
  • 1
  • 2
  • 3

改法

//  model.py
    x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))

   	# x5 = self.upsample(x5)
   	x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

   	# x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))  #改之前
   	x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
   	r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/451744
推荐阅读
相关标签
  

闽ICP备14008679号