当前位置:   article > 正文

【python】图片numpy和pytorch tensor的互相转换_查看图片数据格式是tensor还是numpypython

查看图片数据格式是tensor还是numpypython


numpy和tensor互相转换

用于记录python在深度学习中的使用
经常使用到的场景是从本地opencv读取一张图片,需要将图片由numpy格式转为tensor格式,并经过一些通道变换(h,w,c)->(b,c,w,h) 。转为pytorch可以推理的张量tensor。
推理结束之后还需要将tensor转为numpy数据,用于图片的保存或者显示。

def tensor2img(tensor):
    
    img = tensor.squeeze(0).permute(1,2,0).numpy() #转为tensor
    img = (img[:,:,::-1]*255.0).astype(np.uint8)   #rgb格式转为bgr格式,并转为0-255 int8格式
    return img #(224,224,3) 图片格式

def img2tensor(img):
    img = (img[:,:,::-1].astype(np.float32))/255.0 #bgr格式转为rgb格式,并转为0-1 float32格式
    img = np.transpose(img,(2,0,1)) # (224,224,3) 转为(3,224,224)
    img = img[np.newaxis,:,:,:].copy()  #(1,3,224,224)
    tensor = torch.from_numpy(img)
    return tensor  #(1,3,224,224) 模型输入格式
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

上面代码为img和tensor的具体转换实现。

其中代码的transpose、permute、unsqueeze、squeeze、np.newaxis都是维度的变化

transpose/permute

transpose 可以理解为修改维度的先后顺序,例如将第三维度拉到第一维度,拉完之后在相应维度上的信息没有发生改变。
numpy 的transpose可以变换多个维度
torch tensor 的transpose 只能切换两个维度,permute才能切换多个维度

numpy

# numpy 
num = np.random.randn(3,4,5,6)

##两种实现方式
tran = num.transpose(1,2,3,0)  #维度变成(4,5,6,3)
tran2 = np.transpose(num,(1,2,3,0)) #维度变成(4,5,6,3)
print(tran[:,:,:,0]==(num[0,:,:,:])) #相等
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

结果

[[[ True  True  True  True  True  True]
  [ True  True  True  True  True  True]
  [ True  True  True  True  True  True]
  [ True  True  True  True  True  True]
  [ True  True  True  True  True  True]]...
  • 1
  • 2
  • 3
  • 4
  • 5

tensor

# torch
num_tensor = torch.randn(3,4,5,6)
#tran3 = num_tensor.transpose(3,0,1,2) #会出错 
tran3 = num_tensor.permute(3,0,1,2) #会出错 
print(num_tensor[:,:,:,0]==(tran3[0,:,:,:])) #相等
  • 1
  • 2
  • 3
  • 4
  • 5

结果

tensor([[[True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True],
         [True, True, True, True, True]],....
  • 1
  • 2
  • 3
  • 4

维度变化之unsqueeze,squeeze,np.newaxis

numpy

img = np.random.randn(224,224,3)
print(img.shape)
img = img[:,:,:,np.newaxis]
print(img.shape)
  • 1
  • 2
  • 3
  • 4

结果

(224, 224, 3)
(224, 224, 3, 1)
  • 1
  • 2

tensor

tensor = torch.randn(3,224,224)
tensor = tensor.unsqueeze(0) #在第一个扩展一个维度
print(tensor.shape)
tensor = tensor.squeeze(0) #在第一个减少一个维度
print(tensor.shape)
  • 1
  • 2
  • 3
  • 4
  • 5

结果

torch.Size([1, 3, 224, 224])
torch.Size([3, 224, 224])
  • 1
  • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/219414
推荐阅读
相关标签
  

闽ICP备14008679号