当前位置:   article > 正文

kornia库ssim使用,Invalid image shape_kornia.losses.ssim

kornia.losses.ssim

一、图像张量转四维

将图像张量转换为四维张量的原因-训练神经网络 - 知乎

  1. img1 = cv2.imread('OCT3.png') #读入一张图像
  2. img1 = cv2.resize(img1,(256,256)) #修改尺寸为256*256
  3. img1 = np.array(img1)
  4. tensor = transforms.ToTensor()
  5. img_tensor = tensor(img1) #转化成张量形式,此时的shape是(3,256,256)
  6. img_tensor1 = img_tensor.unsqueeze(0) #有时候需要四维张量运算,把三维变成四维,此时的shape是(1,3,256,256)

       将numpy类型转换为tensor类型:[h, w, c]->[c, h, w]

       还要将tensor类型转换成四维张量:[c, h, w]->[b, c, h, w] 这是什么原因呢?

       其实在深度学习和计算机视觉中,将图像张量转换为四维张量的原因主要与批量处理和神经网络的输入要求有关。以下是详细解释:

  1. 批量处理:在训练神经网络时,通常一次处理多个图像,这称为批量处理。批量处理可以提高计算效率,因为在许多情况下,矩阵运算可以在现代硬件(如GPU)上并行执行。因此,将多个图像组合成一个四维张量可以加速训练过程。四维张量的形状为 (B, C, H, W),其中B是批量大小,表示同时处理的图像数量。
  2. 神经网络输入要求:大多数神经网络(尤其是卷积神经网络)在设计时就考虑了批量处理。这意味着它们的输入要求通常是四维张量,即使您只处理单个图像,也需要将其转换为四维张量。在这种情况下,批量大小为1,张量的形状为 (1, C, H, W)。

       通过将图像张量转换为四维张量,您可以确保它与神经网络的输入要求兼容,并充分利用批量处理带来的计算优势。

二、问题:Invalid image shape, we expect BxCxHxW. Got:

   (pip install kornia,需安装kornia库之后使用,具体代码在github上)

     查看ssim源代码如下,需要四维。一般图片转成张量是三维

if not len(image.shape) == 4: raise ValueError(f"Invalid image shape, we expect BxCxHxW. Got: {image.shape}")

       

       图像转tensor张量,需要先加一维,img_tensor1.unsqueeze(0),进行处理后,如果需要显示,再缩减一维才可以显示为图片tensor1.squeeze(0)。

三、使用ssim函数进行两个图片比较的代码如下(可以选择叠加显示两个图,或者单独显示):

picture_show_single()函数帮助将ssim结果图和原图叠加显示。没有此需要的话直接img3.show()即可

  1. import matplotlib.pyplot as plt
  2. from torchvision import transforms
  3. import kornia.metrics as k
  4. import cv2
  5. import numpy as np
  6. plt.rc("font", family='Microsoft YaHei')
  7. # img3.show()
  8. def picture_show_double(img1,img2):
  9. img1=np.array(img1)
  10. img_result= cv2.addWeighted(img1,0.7,img2,0.3,10)
  11. # img_result = cv2.resize(src=img_result, dsize=None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC)
  12. fig,axes=plt.subplots(nrows=1,ncols=2,figsize=(8,8),dpi=100)
  13. axes[0].set_title("ssim对比结果")
  14. axes[0].imshow(img2[:,:,::-1])
  15. axes[1].set_title("不同位置与原图叠加结果")
  16. axes[1].imshow(img_result[:,:,::-1])
  17. plt.axis('off')
  18. plt.show()
  19. def picture_show_single(img1,img2):
  20. img_result= cv2.addWeighted(img1,0.7,img2,0.3,10)
  21. plt.plot(figsize=(8,8))
  22. plt.imshow(img2[:, :, ::-1])
  23. plt.axis('off')
  24. plt.show()
  25. plt.imshow(img_result[:, :, ::-1])
  26. plt.axis('off')
  27. plt.show()
  28. def ssim_test():
  29. img1 = cv2.imread(r'Figure_1.png')
  30. img2 = cv2.imread(r'Figure_2.png')
  31. # img1=np.array(img1)
  32. print(img1.shape, img2.shape)
  33. transform = transforms.ToTensor()
  34. img_tensor1 = transform(img1)
  35. print(len(img_tensor1.shape))
  36. img_tensor11 = img_tensor1.unsqueeze(0)
  37. print(len(img_tensor11.shape))
  38. img_tensor2 = transform(img2).unsqueeze(0)
  39. print(len(img_tensor2.shape), img_tensor11.shape, img_tensor2.shape)
  40. result_tensor = k.ssim(img_tensor11, img_tensor2, 5)
  41. result_tensor = result_tensor.squeeze(0)
  42. transform = transforms.ToPILImage()
  43. img3 = transform(result_tensor)
  44. img3.show()
  45. img3 = np.array(img3)
  46. picture_show_single(img1, img3)
  47. ssim_test()

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号