
Super-Resolution Metric Computation: Python Code for Evaluating Image Quality with PSNR, SSIM, LPIPS, and NIQE


Overall Goal and Design

Overall Goal

This Python script evaluates image quality: it compares a set of high-definition (HD) reference images against the corresponding generated images and computes four quality metrics: PSNR, SSIM, LPIPS, and NIQE.

Design Philosophy

  1. Modularity: the code is organized into functions, each responsible for a single task, such as computing PSNR or SSIM.
  2. Parallelism: multiprocessing is used to speed up image processing, which matters when handling large numbers of images.
  3. Extensibility: additional image quality metrics can be added easily.
  4. Result logging: a detailed text file is written for each HD image, recording all quality metrics of its associated generated images.

Functional Module Breakdown

Initializing the LPIPS Model

At the start of the code, the LPIPS library is used to initialize a pretrained VGG network, which is later used for the LPIPS perceptual quality assessment.

loss_fn = lpips.LPIPS(net='vgg')

Image Quality Assessment Functions

calculate_psnr, calculate_ssim

These two functions use OpenCV and skimage to compute PSNR and SSIM, respectively. Both are full-reference metrics, so they require the original image and the generated image for comparison.
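As a quick standalone illustration of how these two full-reference metrics are called (the file names here are placeholders, not paths from the script; both images must have the same shape):

import cv2
from skimage.metrics import structural_similarity as compare_ssim

# Placeholder file names; any two same-sized images will do
ref = cv2.imread('reference.png')
gen = cv2.imread('generated.png')

psnr = cv2.PSNR(ref, gen)                         # in dB, higher is better
ssim = compare_ssim(ref, gen, multichannel=True)  # in [-1, 1], higher is better
# On scikit-image >= 0.19, use channel_axis=-1 instead of multichannel=True
print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")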

calculate_lpips

This function uses the initialized LPIPS model to measure the perceptual difference between two images (the original and the generated one).
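Worth noting: the pretrained LPIPS network expects RGB tensors scaled to roughly [-1, 1], while cv2.imread returns BGR uint8 arrays in [0, 255]. A minimal preprocessing sketch along those lines (the helper name is illustrative and not part of the original script):

import cv2
import torch
import lpips

loss_fn = lpips.LPIPS(net='vgg')

def lpips_from_bgr(img1_bgr, img2_bgr):
    # Hypothetical helper: convert OpenCV BGR uint8 images into the tensors LPIPS expects
    def to_tensor(img_bgr):
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        t = torch.from_numpy(img_rgb).float().permute(2, 0, 1).unsqueeze(0)
        return t / 127.5 - 1.0  # scale [0, 255] to [-1, 1]
    with torch.no_grad():
        return loss_fn(to_tensor(img1_bgr), to_tensor(img2_bgr)).item()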

calculate_niqe

This function implements a simplified version of NIQE (no-reference image quality assessment): only the image itself is needed to score its quality, with no reference image.
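For context, the standard NIQE metric fits a multivariate Gaussian to natural-scene-statistics features and scores an image by its distance to a pristine model; the code below substitutes a zero mean and identity covariance for that model. A minimal sketch of the standard distance computation, assuming the pristine parameters would normally be loaded from a pre-fitted model file (hypothetical inputs here):

import numpy as np

def niqe_distance(mu_pris, cov_pris, mu_test, cov_test):
    # mu_pris/cov_pris: mean and covariance of the pristine model (loaded elsewhere)
    # mu_test/cov_test: mean and covariance fitted on the test image's block features
    cov = (cov_pris + cov_test) / 2.0
    diff = (mu_pris - mu_test).reshape(1, -1)
    return float(np.sqrt(diff @ np.linalg.pinv(cov) @ diff.T))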

The Core Image-Processing Function: process_image()

This function is the heart of the script and performs the following steps:

  1. Read the HD image and the generated images: using OpenCV's cv2.imread.
  2. Initialize quality metrics: set up the variables that will hold the best metric values.
  3. Compute and record metrics: for each generated image, compute all quality metrics and keep track of the best values.
  4. Store the results: save the computed metrics to a text file.

Parallel Processing

Python's multiprocessing.Pool is used to process all HD images in parallel. This is essentially the "map" half of a map-reduce pattern, with process_image acting as the map operation:

with Pool(4) as pool:  # Initialize a pool with 4 processes
    pool.starmap(process_image, [(i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder) for i in range(1, 570)])

Performance and Optimization

  1. I/O: because the script reads and writes many files (images and text), I/O can become the performance bottleneck.
  2. Parallelism: multiprocessing can significantly reduce the overall processing time, but the number of worker processes should not exceed the number of CPU cores; see the sketch after this list.
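A small sketch of sizing the pool to the machine instead of hard-coding it (the work function here is a stand-in for process_image, not code from the original script):

import os
from multiprocessing import Pool

def work(i):
    return i * i  # stand-in for process_image(i, ...)

if __name__ == "__main__":
    # Leave one core free for the OS; never exceed the available core count
    num_workers = max(1, (os.cpu_count() or 1) - 1)
    with Pool(num_workers) as pool:
        results = pool.map(work, range(10))
    print(results)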

Best Practices and Suggested Improvements

  1. Configuration file or command-line arguments: hard-coded values (such as paths and the process count) should be avoided; a configuration file or command-line arguments are preferable.
  2. Error handling: the current code performs little error handling; adding exception handling and sanity checks (for example, verifying that cv2.imread actually returned an image) is recommended. A combined sketch of both suggestions follows this list.
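A hedged sketch of both suggestions together, reading paths from the command line and guarding the image reads (the argument names and helper are illustrative, not part of the original script):

import argparse
import cv2

def load_image_or_none(path):
    # Return the image, or None (with a warning) if the file is missing or unreadable
    img = cv2.imread(path)
    if img is None:
        print(f"Warning: could not read {path}, skipping.")
    return img

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Image quality evaluation")
    parser.add_argument("--hd-folder", required=True, help="Folder with HD reference images")
    parser.add_argument("--generated-root", required=True, help="Root folder of generated images")
    parser.add_argument("--output-root", required=True, help="Folder for result files")
    parser.add_argument("--workers", type=int, default=4, help="Number of worker processes")
    args = parser.parse_args()
    print(args.hd_folder, args.generated_root, args.output_root, args.workers)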

Code

import os
from multiprocessing import Pool

import cv2
import lpips
import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from scipy.special import gammaln
from scipy.stats import genpareto
from skimage import img_as_float
from skimage.metrics import structural_similarity as compare_ssim
from tqdm import tqdm

# Initialize LPIPS with a pretrained VGG backbone
loss_fn = lpips.LPIPS(net='vgg')


def calculate_psnr(img1, img2):
    return cv2.PSNR(img1, img2)


def calculate_ssim(img1, img2):
    # On scikit-image >= 0.19, use channel_axis=-1 instead of multichannel=True
    return compare_ssim(img1, img2, multichannel=True)


def calculate_lpips(img1, img2):
    # Convert HWC arrays to NCHW tensors; values are used as loaded (BGR, [0, 255])
    img1 = torch.Tensor(img1).permute(2, 0, 1).unsqueeze(0)
    img2 = torch.Tensor(img2).permute(2, 0, 1).unsqueeze(0)
    return loss_fn(img1, img2).item()


def calculate_niqe(image):
    # Simplified NIQE: block-wise statistics scored against a zero-mean, identity-covariance model
    image = img_as_float(image)
    h, w = image.shape[:2]
    block_size = 96
    strides = 32
    features = []
    for i in range(0, h - block_size + 1, strides):
        for j in range(0, w - block_size + 1, strides):
            block = image[i:i + block_size, j:j + block_size]
            mu = np.mean(block)
            sigma = np.std(block)
            filtered_block = gaussian_filter(block, sigma)
            shape, _, scale = genpareto.fit(filtered_block.ravel(), floc=0)
            feature = [mu, sigma, shape, scale, gammaln(1 / shape)]
            features.append(feature)
    features = np.array(features)
    model_mean = np.zeros(features.shape[1])
    model_cov_inv = np.eye(features.shape[1])
    quality_scores = []
    for feature in features:
        score = (feature - model_mean) @ model_cov_inv @ (feature - model_mean).T
        quality_scores.append(score)
    return np.mean(quality_scores)


def process_image(i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder):
    # Evaluate every generated image for HD image i and record the best metric values
    hd_img_name = f"{i}.png"
    hd_img_path = os.path.join(hd_img_folder, hd_img_name)
    hd_img = cv2.imread(hd_img_path)
    corresponding_generated_folder = os.path.join(generated_img_root_folder, str(i))
    if not os.path.exists(corresponding_generated_folder):
        print(f"Folder for {hd_img_name} does not exist. Skipping.")
        return
    output_file_path = os.path.join(output_root_folder, f"{i}_output.txt")
    generated_img_names = os.listdir(corresponding_generated_folder)
    generated_img_names.sort(key=lambda x: int(x.split('.')[0]))
    total_images = len(generated_img_names)
    best_psnr = 0
    best_ssim = 0
    best_lpips = float('inf')
    best_niqe = float('inf')
    best_metrics_record = {}
    with open(output_file_path, 'w') as f:
        f.write(f"Results for HD Image: {hd_img_name}\n")
        f.write("-------------------------------------\n")
        for generated_img_name in tqdm(generated_img_names, total=total_images, desc=f"Processing for {hd_img_name}",
                                       leave=False):
            generated_img_path = os.path.join(corresponding_generated_folder, generated_img_name)
            generated_img = cv2.imread(generated_img_path)
            psnr = calculate_psnr(hd_img, generated_img)
            ssim = calculate_ssim(hd_img, generated_img)
            lpips_value = calculate_lpips(hd_img, generated_img)
            niqe = calculate_niqe(generated_img)
            result_line = f"{generated_img_name} PSNR: {psnr} SSIM: {ssim} LPIPS: {lpips_value} NIQE: {niqe}\n"
            f.write(result_line)
            if psnr > best_psnr:
                best_psnr = psnr
                best_metrics_record['Best PSNR'] = (generated_img_name, best_psnr)
            if ssim > best_ssim:
                best_ssim = ssim
                best_metrics_record['Best SSIM'] = (generated_img_name, best_ssim)
            if lpips_value < best_lpips:
                best_lpips = lpips_value
                best_metrics_record['Best LPIPS'] = (generated_img_name, best_lpips)
            if niqe < best_niqe:
                best_niqe = niqe
                best_metrics_record['Best NIQE'] = (generated_img_name, best_niqe)
    with open(main_output_file_path, 'a') as main_f:
        main_f.write(f"Best Metrics for {hd_img_name}\n")
        main_f.write("-------------------------------------\n")
        for metric, (img_name, value) in best_metrics_record.items():
            main_f.write(f"{metric}: {img_name}, Value: {value}\n")
        main_f.write("\n")
    print(f"Best Metrics for {hd_img_name} are saved in {main_output_file_path}")


if __name__ == "__main__":
    hd_img_folder = 'xxxxxxx'
    generated_img_root_folder = 'xxxxxxx'
    output_root_folder = 'xxxxxxx'
    main_output_file_path = os.path.join(output_root_folder, "all_results.txt")
    with open(main_output_file_path, 'w') as main_f:
        main_f.write("Detailed Results for Each High-Definition Image\n")
        main_f.write("==============================================\n")
    with Pool(4) as pool:  # Initialize a pool with 4 processes
        pool.starmap(process_image,
                     [(i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder) for i in
                      range(1, 570)])  # Parallel processing
