赞
踩
这个Python脚本主要用于评估图像质量,它比较了一组高清(HD)图像和对应的生成图像,并计算了四种不同的图像质量指标:PSNR、SSIM、LPIPS和NIQE。
代码开头使用LPIPS库初始化了一个基于预训练VGG网络的模型,用于后续的LPIPS图像质量评估。
loss_fn = lpips.LPIPS(net='vgg')  # LPIPS perceptual-distance model with a VGG backbone, built once and reused
这两个函数分别使用OpenCV和skimage库来计算PSNR和SSIM。这些都是全参考指标,需要原图和生成图进行比较。
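这个函数使用初始化的LPIPS模型来评估两个图像(原图和生成图)之间的感知差异。注意:LPIPS模型要求输入为形状 N×3×H×W、RGB通道顺序、数值归一化到[-1, 1]的张量,而 cv2.imread 读取的是 BGR 通道顺序、取值 0–255 的 uint8 图像,因此需要先做通道翻转和数值缩放。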
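这个函数实现了一个类似NIQE(Naturalness Image Quality Evaluator,一种无参考图像质量评估方法)的指标,即只需要一张图像即可评估其质量。注意:它使用零均值、单位协方差的占位模型,而不是由自然图像训练得到的模型,因此并非标准的NIQE,其数值不能与文献中的NIQE结果直接比较。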
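process_image 函数是代码的核心,它执行以下操作:使用 cv2.imread 方法读取一张高清图像及其对应文件夹中的所有生成图像,逐一计算上述四个指标并写入该图像的结果文件,最后把每个指标下表现最好的生成图像追加到汇总文件中。主程序使用 Python 的 multiprocessing.Pool 来并行处理所有高清图像。这是一种典型的"Map-Reduce"式模式,其中 process_image 函数是 map 操作;汇总步骤由各进程直接追加写入同一个汇总文件完成。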
# One task per HD image id 1..569, fanned out over 4 worker processes.
worker_args = [
    (i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder)
    for i in range(1, 570)
]
with Pool(4) as pool:
    pool.starmap(process_image, worker_args)
- import os
- from multiprocessing import Pool
-
- import cv2
- import lpips
- import numpy as np
- import torch
- from scipy.ndimage import filters
- from scipy.special import gammaln
- from scipy.stats import genpareto
- from skimage import img_as_float
- from skimage.metrics import structural_similarity as compare_ssim
- from tqdm import tqdm
-
# Shared LPIPS model (VGG backbone): constructed once when the module is
# imported and reused by every calculate_lpips call.
loss_fn = lpips.LPIPS(net="vgg")
-
-
def calculate_psnr(img1, img2):
    """Return OpenCV's peak signal-to-noise ratio of ``img2`` against ``img1``.

    Higher means closer to the reference. OpenCV's default peak value ``R``
    is used.
    """
    psnr_value = cv2.PSNR(img1, img2)
    return psnr_value
-
-
def calculate_ssim(img1, img2):
    """Return the structural similarity (SSIM) of two same-shape images.

    Colour images ``(H, W, C)`` are compared along their last (channel) axis.
    2-D grayscale images are compared directly, rather than having their
    width axis misread as channels. The data range is left to skimage, which
    infers it from the dtype for integer images such as ``cv2.imread``'s
    ``uint8`` output.
    """
    import inspect  # local: only used to probe which keyword the installed skimage accepts

    if img1.ndim == 2:
        return compare_ssim(img1, img2)
    if 'channel_axis' in inspect.signature(compare_ssim).parameters:
        # skimage >= 0.19 takes an explicit channel axis; ``multichannel`` is deprecated there.
        return compare_ssim(img1, img2, channel_axis=-1)
    # Older skimage only understands the boolean ``multichannel`` flag.
    return compare_ssim(img1, img2, multichannel=True)
-
-
def calculate_lpips(img1, img2, model=None):
    """Return the LPIPS perceptual distance between two images (lower = more similar).

    The LPIPS network expects RGB tensors of shape (N, 3, H, W) scaled to
    [-1, 1]. ``cv2.imread`` yields BGR ``uint8`` arrays in [0, 255], so both
    images are converted before being passed to the model.

    Args:
        img1, img2: ``(H, W, 3)`` BGR images with values in [0, 255], as
            returned by ``cv2.imread``. They are assumed to share a shape.
        model: Callable taking two image tensors and returning a one-element
            tensor. Defaults to the module-level ``loss_fn`` (LPIPS, VGG backbone).

    Returns:
        The distance as a Python float.
    """
    net = loss_fn if model is None else model

    def _to_lpips_tensor(img):
        # Reverse the channel axis (BGR -> RGB). ascontiguousarray materialises
        # the reversed view, because torch.from_numpy rejects negative strides.
        rgb = np.ascontiguousarray(img[..., ::-1], dtype=np.float32)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
        return tensor / 127.5 - 1.0  # [0, 255] -> [-1, 1]

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        return net(_to_lpips_tensor(img1), _to_lpips_tensor(img2)).item()
-
-
def calculate_niqe(image, block_size=96, strides=32):
    """Return a NIQE-style, no-reference quality score for ``image``.

    The caller (process_image) treats lower as better.

    NOTE: despite its name this is *not* the standard NIQE of Mittal et al.
    (2013). For every block it fits a generalized Pareto distribution to a
    Gaussian-smoothed copy of the block. It then measures the per-block
    feature vectors against a placeholder model (zero mean, identity
    covariance), not against a model learned from pristine images. Scores are
    only comparable with each other, not with published NIQE values.

    Args:
        image: Image array whose first two axes are height and width
            (e.g. as returned by ``cv2.imread``).
        block_size: Side length of the square blocks features are taken from.
        strides: Step between neighbouring block origins, in pixels.

    Returns:
        Mean, over all blocks, of the squared Euclidean norm of each block's
        feature vector ``[mean, std, gpd_shape, gpd_scale, gammaln(1/shape)]``.
        Returns ``nan`` when the image is smaller than one block, because
        there is then nothing to score.
    """
    h, w = image.shape[:2]
    if h < block_size or w < block_size:
        # Previously crashed with an IndexError on the empty feature array.
        return float('nan')

    # Local import: gaussian_filter's old home, scipy.ndimage.filters, is deprecated.
    from scipy.ndimage import gaussian_filter

    image = img_as_float(image)
    features = []
    for i in range(0, h - block_size + 1, strides):
        for j in range(0, w - block_size + 1, strides):
            block = image[i:i + block_size, j:j + block_size]
            mu = np.mean(block)
            sigma = np.std(block)
            # The block's own standard deviation is used as the smoothing sigma.
            # NOTE(review): with a scalar sigma, a colour block is also smoothed
            # across its channel axis — confirm that is intended.
            filtered_block = gaussian_filter(block, sigma)
            shape, _, scale = genpareto.fit(filtered_block.ravel(), floc=0)
            features.append([mu, sigma, shape, scale, gammaln(1 / shape)])

    features = np.asarray(features)
    # With a zero model mean and identity covariance, the Mahalanobis form
    # (f - m) @ inv(C) @ (f - m) reduces to the squared norm of f.
    return np.mean(np.sum(features ** 2, axis=1))
-
-
def process_image(i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder):
    """Score every generated candidate for HD image ``i`` and record the best ones.

    Reads ``<hd_img_folder>/<i>.png`` and the numerically named images in
    ``<generated_img_root_folder>/<i>/``. Each candidate is compared with the
    HD image using PSNR, SSIM and LPIPS; NIQE is computed on the candidate
    alone. Per-candidate scores are written to
    ``<output_root_folder>/<i>_output.txt``. The best candidate per metric is
    appended to ``main_output_file_path``. Higher is better for PSNR and SSIM;
    lower is better for LPIPS and NIQE.

    Args:
        i: Integer id of the HD image; also the name of its candidate folder.
        main_output_file_path: Summary file shared by all workers (appended to).
        hd_img_folder: Folder holding the HD reference images.
        generated_img_root_folder: Folder with one sub-folder of candidates per HD image.
        output_root_folder: Folder for the per-image result files.

    Returns:
        None. Problems are reported and skipped rather than raised:
        a missing HD image or folder is reported with ``print``, and an
        unreadable or wrongly shaped candidate gets a line in the per-image file.
    """
    hd_img_name = f"{i}.png"
    hd_img_path = os.path.join(hd_img_folder, hd_img_name)
    hd_img = cv2.imread(hd_img_path)
    if hd_img is None:
        # cv2.imread reports a missing or unreadable file by returning None, not by raising.
        print(f"Could not read {hd_img_path}. Skipping.")
        return

    corresponding_generated_folder = os.path.join(generated_img_root_folder, str(i))
    if not os.path.exists(corresponding_generated_folder):
        print(f"Folder for {hd_img_name} does not exist. Skipping.")
        return

    output_file_path = os.path.join(output_root_folder, f"{i}_output.txt")
    # Candidates are named "<number>.<ext>" and processed in numeric order.
    # Other entries (e.g. ".DS_Store") would make the int() sort key raise, so they are ignored.
    generated_img_names = sorted(
        (name for name in os.listdir(corresponding_generated_folder) if name.split('.')[0].isdecimal()),
        key=lambda name: int(name.split('.')[0]),
    )

    # Metric label -> True when a higher value is better. Dict order also fixes the summary order.
    higher_is_better = {'PSNR': True, 'SSIM': True, 'LPIPS': False, 'NIQE': False}
    best = {}  # label -> (generated image name, value)

    with open(output_file_path, 'w') as f:
        f.write(f"Results for HD Image: {hd_img_name}\n")
        f.write("-------------------------------------\n")

        for generated_img_name in tqdm(generated_img_names, total=len(generated_img_names),
                                       desc=f"Processing for {hd_img_name}", leave=False):
            generated_img_path = os.path.join(corresponding_generated_folder, generated_img_name)
            generated_img = cv2.imread(generated_img_path)
            if generated_img is None or generated_img.shape != hd_img.shape:
                # The full-reference metrics need a readable image with exactly the HD image's shape.
                f.write(f"{generated_img_name} skipped: unreadable or shape differs from HD image\n")
                continue

            scores = {
                'PSNR': calculate_psnr(hd_img, generated_img),
                'SSIM': calculate_ssim(hd_img, generated_img),
                'LPIPS': calculate_lpips(hd_img, generated_img),
                'NIQE': calculate_niqe(generated_img),
            }
            f.write(f"{generated_img_name} PSNR: {scores['PSNR']} SSIM: {scores['SSIM']} "
                    f"LPIPS: {scores['LPIPS']} NIQE: {scores['NIQE']}\n")

            for label, value in scores.items():
                if np.isnan(value):
                    # NaN compares False against everything; keep it out of the running best.
                    continue
                current = best.get(label)
                # No fixed starting threshold: SSIM can be negative, so a 0 baseline would drop it.
                if (current is None
                        or (higher_is_better[label] and value > current[1])
                        or (not higher_is_better[label] and value < current[1])):
                    best[label] = (generated_img_name, value)

    summary = [f"Best Metrics for {hd_img_name}\n", "-------------------------------------\n"]
    for label in higher_is_better:
        if label in best:
            img_name, value = best[label]
            summary.append(f"Best {label}: {img_name}, Value: {value}\n")
    summary.append("\n")
    # Several worker processes append to this file concurrently. Emitting the
    # whole block in one write call reduces the chance of it being interleaved.
    with open(main_output_file_path, 'a') as main_f:
        main_f.write(''.join(summary))
    print(f"Best Metrics for {hd_img_name} are saved in {main_output_file_path}")
-
-
if __name__ == "__main__":
    hd_img_folder = 'xxxxxxx'
    generated_img_root_folder = 'xxxxxxx'
    output_root_folder = 'xxxxxxx'
    # HD images are expected as 1.png .. 569.png (ids are inclusive here).
    first_image_id, last_image_id = 1, 569
    num_workers = 4

    # Every open() below writes into this folder, so make sure it exists.
    os.makedirs(output_root_folder, exist_ok=True)
    main_output_file_path = os.path.join(output_root_folder, "all_results.txt")

    # Truncate the summary file and write its header before workers start appending.
    with open(main_output_file_path, 'w') as main_f:
        main_f.write("Detailed Results for Each High-Definition Image\n")
        main_f.write("==============================================\n")

    tasks = [
        (i, main_output_file_path, hd_img_folder, generated_img_root_folder, output_root_folder)
        for i in range(first_image_id, last_image_id + 1)
    ]
    # HD images are scored independently, so they are processed in parallel.
    with Pool(num_workers) as pool:
        pool.starmap(process_image, tasks)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。