当前位置:   article > 正文

视频质量评价python,包括psnr,ssim,LPIPS_lpips视频质量评估

lpips视频质量评估

文件准备:

创建两个文件夹,分别叫做video_gen,video_real。

video_gen下放置生成的视频的帧,以1.jpg样式命名。

video_real下放置真值,命名方式同上。

评价代码:

  1. """
  2. Video Quality Metrics
  3. Copyright (c) 2014 Alex Izvorski <aizvorski@gmail.com>
  4. This program is free software: you can redistribute it and/or modify
  5. it under the terms of the GNU General Public License as published by
  6. the Free Software Foundation, either version 3 of the License, or
  7. (at your option) any later version.
  8. This program is distributed in the hope that it will be useful,
  9. but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. GNU General Public License for more details.
  12. You should have received a copy of the GNU General Public License
  13. along with this program. If not, see <http://www.gnu.org/licenses/>.
  14. """
  15. import os
  16. # time:2023.10.14 9:12
  17. # author: Yuanping
  18. # 注意修改main中的图片格式,如jpg还是png等
  19. # 评估很多帧图片psnr并计算平局值的函数,只需改变gen_path和real_path就行,或者将相应jpg格式图片放入相应文件夹即可
  20. # 最好都是相同命名,且gen的数量得小于等于real
  21. # 会自动将评估得到的值存入csv,包括每一帧的和平均值
  22. # 不需要计算某个值可以在video_quality函数中把计算某个值的过程注释掉
  23. from glob import glob
  24. import lpips
  25. from skimage.metrics import structural_similarity as ssim
  26. import cv2
  27. import numpy
  28. import math
  29. import numpy as np
  30. import pandas as pd
  31. def psnr(img1, img2):
  32. mse = numpy.mean((img1 - img2) ** 2)
  33. if mse == 0:
  34. print("error:mse=0")
  35. exit(1)
  36. # return 0
  37. PIXEL_MAX = 255.0
  38. return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
  39. def cal_psnr(gen_path, real_path):
  40. psnr_list = []
  41. for i in range(len(gen_path)):
  42. gen_pic = cv2.imread(gen_path[i])
  43. real_pic = cv2.imread(real_path[i])
  44. print("frame picture:")
  45. psnr_value = psnr(gen_pic, real_pic)
  46. print(i, " frame psnr: ", psnr_value)
  47. psnr_list.append([i, psnr_value])
  48. psnr_csv = pd.DataFrame(psnr_list, columns=['frame', 'psnr_value'])
  49. # psnr_csv = psnr_csv._append({'frame': 'average', 'psnr_value': psnr_csv["psnr_value"].mean()}, ignore_index=True)
  50. # psnr_csv.to_csv('video_psnr.csv', index=False)
  51. return psnr_csv
  52. def cal_ssim(gen_path, real_path, quality_df):
  53. ssim_list = []
  54. for i in range(len(gen_path)):
  55. gen_pic = cv2.imread(gen_path[i])
  56. real_pic = cv2.imread(real_path[i])
  57. ssim_value = ssim(gen_pic, real_pic, channel_axis=-1)
  58. print(i, " frame ssim: ", ssim_value)
  59. ssim_list.append(ssim_value)
  60. quality_df['ssim_value'] = ssim_list
  61. return quality_df
  62. def cal_LPIPS(gen_path, real_path, quality_df):
  63. """
  64. 参考代码:https://blog.csdn.net/weixin_43135178/article/details/127664187
  65. """
  66. ## Initializing the model
  67. loss_fn = lpips.LPIPS(net='alex', version=0.1) # pip install lpips
  68. lpips_list = []
  69. for i in range(len(gen_path)):
  70. try:
  71. # Load images
  72. img0 = lpips.im2tensor(lpips.load_image(gen_path[i]))
  73. img1 = lpips.im2tensor(lpips.load_image(real_path[i]))
  74. # Compute distance,之所以会有后面detach和numpy是因为要把梯度去掉,这样才能加到dataframe里面
  75. # 之所以还有个mean函数,是因为不加mean的话得到的是[[[value]]],加了mean可以直接得到值
  76. current_lpips_distance = loss_fn.forward(img0, img1).detach().numpy().mean()
  77. print(i, " frame lpips: ", current_lpips_distance)
  78. lpips_list.append(current_lpips_distance)
  79. except Exception as e:
  80. print(e)
  81. quality_df['lpips_value'] = lpips_list
  82. return quality_df
  83. def video_quality(gen_path, real_path):
  84. quality_df = cal_psnr(gen_path, real_path)
  85. quality_df = cal_ssim(gen_path, real_path, quality_df)
  86. quality_df = cal_LPIPS(gen_path, real_path, quality_df)
  87. quality_df = quality_df._append({'frame': 'average',
  88. 'ssim_value': quality_df["ssim_value"].mean(),
  89. 'psnr_value': quality_df["psnr_value"].mean(),
  90. 'lpips_value': quality_df["lpips_value"].mean()},
  91. ignore_index=True)
  92. quality_df.to_csv('video_quality.csv', index=False)
  93. if __name__ == '__main__':
  94. gen_path = glob("video_gen/*.jpg")
  95. real_path = glob("video_real/*.jpg")
  96. video_quality(gen_path, real_path)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/192915
推荐阅读
相关标签
  

闽ICP备14008679号