赞
踩
1.文章:(CVPR 2020) Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement
2. 链接: paper.
3. 链接: code.
4. 其他博主复现链接: link.
1.检查电脑是否含有GPU,不是GPU环境需要将代码改为CPU环境下
2. 源代码中没有数据集
3. 代码中路径问题,路径不对将导致无法加载数据
3.需要按照要求设置文件夹
4. 由于版本不同,代码运行中出现的警告影响代码运行
论文中的代码是GPU环境下的,并且是在每一句需要用到GPU的代码下注释的,所以需要将所有含有该语令的改为CPU
具体改法:
将代码中".cuda()"删去,或者将其改为“.cpu()”
下面仅仅是部分示例:其中前面带“#”为源代码
// Myloss.py
# kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
# kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
# kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
# kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)
kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).unsqueeze(0).unsqueeze(0)
# weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())
weight_diff =torch.max(torch.FloatTensor([1]) + 10000*torch.min(org_pool - torch.FloatTensor([0.3]),torch.FloatTensor([0])),torch.FloatTensor([0.5]))
# E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)
E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5])), enhance_pool - org_pool)
在CPU条件下 在训练模型中需要将pin_memory=True改为pin_memory=False
// lowlight_train.py
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True,num_workers=config.num_workers, pin_memory=False)
该论文源代码里面有预训练模型,所以没有训练数据集,并不影响直接测试代码。
该Epoch99.pth即为预训练模型。所以训练集并不影响测试。
当然你也可以根据代码要求寻找数据集加入,并训练出自己的预训练模型
在dataloader.py文件中 :
作用是为数据做预处理的相关代码。
文件路径出现了“.jpg”,如果训练集中的数据为PNG格式,就需要将最后改为“.png”
// dataloader.py
def populate_train_list(lowlight_images_path):#获取训练列表(微光图像路径)
# image_list_lowlight = glob.glob(lowlight_images_path + "*.jpg")
image_list_lowlight = glob.glob(lowlight_images_path + "*.png")
train_list = image_list_lowlight
random.shuffle(train_list)
return train_list
在 lowlight_train.py文件中 :
将文件路径改为自己的数据路径
绝对路径,相对路径都可以
“在开始调试的代码过程中可以更加学习率与减少迭代次数,在代码调通后在改为规定的参数,可以运行更快”
// lowlight_train.py if __name__ == "__main__": parser = argparse.ArgumentParser() # Input Parameters # parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/") parser.add_argument('--lowlight_images_path', type=str, default="E:/image/Zero-DCE-master/Zero-DCE_code/data/train_data/") # parser.add_argument('--lr', type=float, default=0.0001) # parser.add_argument('--weight_decay', type=float, default=0.0001) parser.add_argument('--grad_clip_norm', type=float, default=0.1) parser.add_argument('--lr', type=float, default=0.01) parser.add_argument('--weight_decay', type=float, default=0.01) parser.add_argument('--num_epochs', type=int, default=20) # parser.add_argument('--num_epochs', type=int, default=20) parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument('--val_batch_size', type=int, default=4) parser.add_argument('--num_workers', type=int, default=4) parser.add_argument('--display_iter', type=int, default=10) parser.add_argument('--snapshot_iter', type=int, default=10) parser.add_argument('--snapshots_folder', type=str, default="E:/image/Zero-DCE-master/Zero-DCE_code/snapshots/") # parser.add_argument('--snapshots_folder', type=str, default="snapshots/") parser.add_argument('--load_pretrain', type=bool, default= False) parser.add_argument('--pretrain_dir', type=str, default= "E:/image/Zero-DCE-master/Zero-DCE_code/snapshots/Epoch99.pth") # parser.add_argument('--pretrain_dir', type=str, default="snapshots/Epoch99.pth") config = parser.parse_args()
在data文件夹下设置result文件夹,并在result文件下设置与test_data文件夹下一样的两个文件夹:DICM与LIME文件夹
必须一样,否则将报错
一、 torch.nn.utils.clip_grad_norm函数被弃用
// torch.nn.utils.clip_grad_norm函数被弃用 的警告
UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_.
torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)
改法;
// lowlight_train.py
# torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)
torch.nn.utils.clip_grad_norm_(DCE_net.parameters(),config.grad_clip_norm)
二、 nn.functional.tanh被弃用
// nn.functional.tanh被弃用的警告
UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
改法
// model.py
x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
# x5 = self.upsample(x5)
x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
# x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1))) #改之前
x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。