当前位置:   article > 正文

批量测试图像的LPIPS、SSIM、PSNR指标,并生成CSV文件结果_python lpips 评价指标批量评价

python lpips 评价指标批量评价

目录

背景:

代码:

结果


背景:

由于论文需要,我将测试代码进行了修改,融入了LPIPS等指标,关于LPIPS介绍,可以查看博客:

LPIPS图像相似性度量标准:The Unreasonable Effectiveness of Deep Features as a Perceptual Metric_Alocus的博客-CSDN博客_lpips度量

代码:

代码中有两大部分需要进行修改,我在代码中进行了标注。

需要安装一些包,如pip install lpips

有疑问欢迎联系交流。

  1. # -*- coding:utf-8 _*-
  2. import argparse
  3. import os
  4. import cv2
  5. import pandas as pd
  6. import lpips
  7. import torch
  8. import torchvision.transforms as transforms
  9. from skimage.metrics import peak_signal_noise_ratio
  10. from skimage.metrics import structural_similarity
  11. """
  12. author&wechat:Alocus
  13. QQ:1913434222
  14. data:2022.02.27
  15. """
  16. #########################需要修改的部分(1)#######################
  17. #abs_path = r'C:\Users\Administrator\Desktop\test\metric' #绝对路径,存放csv文件位置
  18. abs_path = os.getcwd()
  19. test_set_name = r'Result' #用于生成csv结果文件名字
  20. #result_dir = r'C:\Users\Administrator\Desktop\test\metric\haze' #存放结果的文件夹路径及名字
  21. result_dir = os.path.join(abs_path,'EPDN_result' )
  22. #GT_dir = r'C:\Users\Administrator\Desktop\test\metric\GT' #存放GT的文件夹路径及名字
  23. GT_dir = os.path.join(abs_path,'test-label' )
  24. split = '_' #从result中分割出GT名字的符号,如10222_01_0.8411.png 中第一个符号为_,分割后GT为10222.png。 如何实际二者相同,请选择GT中没有的符号。
  25. ###########################################################
  26. #存储列表
  27. result_images =[] #结果图片名字列表
  28. GT_image = [] #GT图片名字列表
  29. image_number = [] #图片读取求指标的id
  30. image_name = [] #图片读取求指标的id 对应的图片名
  31. psnr_number = [] #psnr值列表
  32. ssim_number = [] #ssim值列表
  33. lpips_number = [] #lpips值列表
  34. for root, _, fnames in sorted(os.walk(result_dir)):
  35. for fname in fnames:
  36. #path = os.path.join(root, fname)
  37. result_images.append(fname) ##结果图片名字列表
  38. for root, _, fnames in sorted(os.walk(GT_dir)):
  39. for fname in fnames:
  40. #path = os.path.join(root, fname)
  41. GT_image.append(fname) #GT图片名字列表
  42. for i in range(len(result_images)):
  43. print("now"+str(i))
  44. name_extension = result_images[i] #result图片全名
  45. (name, extension) = os.path.splitext(name_extension) #分离名字和后缀
  46. ###########拼接出对应的GT中的全名###需要修改的部分(2)#########################
  47. index = name.find(split) #根据编名方式,提取出GT部分名字结束下标
  48. if index == -1: #没有找到。则直接按照结果的名字拼接成GT名字
  49. index = len(name)
  50. find_name = name[:index] #提取出在GT中对应的图片名
  51. GT_name = 'label'+find_name[7:11]+".png" #拼接结果
  52. result_name_path =os.path.join(result_dir,name_extension)#拼接result 路径和名字
  53. GT_name_path =os.path.join(GT_dir,GT_name) #拼接GT路径和名字
  54. # print(result_name_path)
  55. # print(GT_name_path)
  56. #开始计算指标
  57. result = cv2.imread(result_name_path)
  58. GT = cv2.imread(GT_name_path)
  59. psnr = peak_signal_noise_ratio(GT,result)
  60. ssim= structural_similarity(GT,result, multichannel=True)
  61. loss_fn_alex = lpips.LPIPS(net='alex',version=0.1) # best forward scores
  62. loss_fn_vgg = lpips.LPIPS(net='vgg',version=0.1) # closer to "traditional" perceptual loss, when used for optimization
  63. test1_res = result
  64. test1_label = GT
  65. transf = transforms.ToTensor()
  66. test1_label = transf(test1_label)
  67. test1_res = transf(test1_res)
  68. test1_ress = test1_res.to(torch.float32).unsqueeze(0)
  69. test1_labell = test1_label.to(torch.float32).unsqueeze(0)
  70. lpips_loss = loss_fn_alex(test1_ress, test1_labell)
  71. #计算结果存入列表
  72. image_number.append(str(i))
  73. image_name.append(name)
  74. psnr_number.append(psnr)
  75. ssim_number.append(ssim)
  76. lpips_number.append(lpips_loss)
  77. #计算列表中指标值的平均值函数
  78. def ave(lis):
  79. s = 0
  80. total_num = len(lis)
  81. for i in lis:
  82. s = s + i
  83. return s/total_num
  84. #计算列表中指标值的平均值,并加入列表
  85. total = 'total(' + str(len(image_number)) + ')'
  86. image_number.append(total)
  87. image_name.append('average')
  88. psnr_ave = ave(psnr_number)
  89. ssim_ave = ave(ssim_number)
  90. lpips_ave = ave(lpips_number)
  91. psnr_number.append(psnr_ave)
  92. ssim_number.append(ssim_ave)
  93. lpips_number.append(lpips_ave)
  94. #生成csv文件
  95. dit = {'image_number':image_number, 'result_name':image_name, 'psnr':psnr_number,'ssim':ssim_number,'lpips':lpips_number}
  96. df = pd.DataFrame(dit)
  97. csv_name = test_set_name + '_ssim&psnr&lpips.csv' #拼接csv名字
  98. csv_path = os.path.join(abs_path,csv_name) #csv全路径
  99. df.to_csv(csv_path,columns=['image_number','result_name','psnr','ssim','lpips'],index=False,sep=',')
  100. print('————————————————————————————finish————————————————————————————')
  101. print('csv_save_path:',csv_path) ##csv全路径
  102. print('result_photos_num:',len(result_images)) #result_photos_num
  103. print('GT_photos_num:',len(GT_image)) #GT_photos_num
  104. print('psnr_ave:',psnr_ave) #psnr_ave
  105. print('ssim_ave:',ssim_ave) #ssim_ave
  106. print('lpips_ave:',lpips_ave) #ssim_ave

结果:

注:有小伙伴测试时还是出现问题 ,我打包了个测试文件,图片可以按其命名放入对应文件夹,或者修改对应代码。

链接:https://pan.baidu.com/s/1xcScdzQofvQPkuVPy1FAVw?pwd=LUCK 
提取码:LUCK

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/192910
推荐阅读
相关标签
  

闽ICP备14008679号