当前位置:   article > 正文

Bug解决-RuntimeError: Sizes of tensors must match except in dimension 2. Got 320 and 160 (The offendin_在服务器上运行报错runtimeerror: sizes of tensors must match

在服务器上运行报错runtimeerror: sizes of tensors must match except in dimens

Bug:

RuntimeError: Sizes of tensors must match except in dimension 2. Got 320 and 160 (The offending index is 0)

解决方法:

1.获取Tensor大小

获取 tensor 的方法有两种:shape 和 size(),其中shape是其属性,而 size() 是其继承的方法,两者均可以获得 tensor 的维度。 shape 是属性,使用中括号,size() 是函数,使用小括号

  1. import torch
  2. a = torch.tensor([[1, 2, 3], [4, 5, 6]])
  3. print(a.shape)
  4. print(a.size())

另外,还可以获取其中的某一维度

  1. print(a.shape[0])
  2. print(a.shape[1])
  3. print(a.size(0))
  4. print(a.size(1))

2.了解torch.cat的使用方法

torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起。

  1. C = torch.cat( (A,B),0 ) #按维数0拼接(竖着拼)
  2. C = torch.cat( (A,B),1 ) #按维数1拼接(横着拼)
  1. import torch
  2. >>> A=torch.ones(2,3) #2x3的张量(矩阵)
  3. >>> A
  4. tensor([[ 1., 1., 1.],
  5. [ 1., 1., 1.]])
  6. >>> B=2*torch.ones(4,3) #4x3的张量(矩阵)
  7. >>> B
  8. tensor([[ 2., 2., 2.],
  9. [ 2., 2., 2.],
  10. [ 2., 2., 2.],
  11. [ 2., 2., 2.]])
  12. >>> C=torch.cat((A,B),0) #按维数0(行)拼接
  13. >>> C
  14. tensor([[ 1., 1., 1.],
  15. [ 1., 1., 1.],
  16. [ 2., 2., 2.],
  17. [ 2., 2., 2.],
  18. [ 2., 2., 2.],
  19. [ 2., 2., 2.]])
  20. >>> C.size()
  21. torch.Size([6, 3])
  22. >>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
  23. >>> C=torch.cat((A,D),1)#按维数1(列)拼接
  24. >>> C
  25. tensor([[ 1., 1., 1., 2., 2., 2., 2.],
  26. [ 1., 1., 1., 2., 2., 2., 2.]])
  27. >>> C.size()
  28. torch.Size([2, 7])

上面给出了两个张量A和B,分别是2行3列,4行3列。即他们都是2维张量。因为只有两维,这样在用torch.cat拼接的时候就有两种拼接方式:按行拼接和按列拼接。即所谓的维数0和维数1. 

C=torch.cat((A,B),0)就表示按维数0(行)拼接A和B,也就是竖着拼接,A上B下。此时需要注意:列数必须一致,即维数1数值要相同,这里都是3列,方能列对齐。拼接后的C的第0维是两个维数0数值和,即2+4=6.

C=torch.cat((A,B),1)就表示按维数1(列)拼接A和B,也就是横着拼接,A左B右。此时需要注意:行数必须一致,即维数0数值要相同,这里都是2行,方能行对齐。拼接后的C的第1维是两个维数1数值和,即3+4=7.

从2维例子可以看出,使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐

实例:在深度学习处理图像时,常用的有3通道的RGB彩色图像及单通道的灰度图。张量size为cxhxw,即通道数x图像高度x图像宽度。在用torch.cat拼接两张图像时一般要求图像大小一致而通道数可不一致,即h和w同,c可不同。当然实际有3种拼接方式,另两种好像不常见。比如经典网络结构:U-Net. 里面用到4次torch.cat,其中copy and crop操作就是通过torch.cat来实现的。可以看到通过上采样(up-conv 2x2)将原始图像h和w变为原来2倍,再和左边直接copy过来的同样h,w的图像拼接。这样做,可以有效利用原始结构信息。

参考链接:Pytorch中的torch.cat()函数_荷叶田田_的博客-CSDN博客_python torch.cat

3.使用PixelShuffle上采样图像

 

 上图中左侧第一部分是用于对图像的特征进行抽取。而后在倒数第二层生成 r*r 个通道特征图,这里 r 就是希望上采样的倍数pixelshuffle的主要功能就是将这 r*r 个通道的特征图组合为新的 w∗r,h∗r 的上采样结果。具体来说,就是将原来一个低分辨的像素划分为rr个更小的格子,利用rr个特征图对应位置的值按照一定的规则来填充这些小格子。按照同样的规则将每个低分辨像素划分出的小格子填满就完成了重组过程。在这一过程中模型可以调整 r*r 个shuffle通道权重 不断优化生成的结果。主要实现了这样的功能:N * (C * r * r) * W * H ——>> N * C * (H * r) * (W * r)

代码实现:

  1. import torch
  2. a = torch.arange(36).reshape([1, 4, 3, 3])
  3. b = torch.pixel_shuffle(a, 2)
  4. print(a.shape) #torch.Size([1, 4, 3, 3])
  5. print(b.shape) #torch.Size([1, 1, 6, 6])

出现Bug:

RuntimeError: c % upscale_factor_squared == 0 INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1607370117127/work/aten/src/ATen/native/PixelShuffle.cpp":24, please report a bug to PyTorch. pixel_shuffle expects input channel to be divisible by square of upscale_factor, but got input with sizes [4, 2, 160, 160], upscale_factor=2, and self.size(1)=2 is not divisible by 4

解决方法:使用

m=nn.Upsample(scale_factor=2)

4.尺度一致后 torch.cat不会报错

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/149585
推荐阅读
相关标签
  

闽ICP备14008679号