当前位置:   article > 正文

Pytorch:torch.flatten()与torch.nn.Flatten()_torch.randn(32, 1, 5, 5)

torch.randn(32, 1, 5, 5)

 torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化。

  1. import torch
  2. x=torch.randn(2,4,2)
  3. print(x)
  4. z=torch.flatten(x)
  5. print(z)
  6. w=torch.flatten(x,1)
  7. print(w)
  8. 输出为:
  9. tensor([[[-0.9814, 0.8251],
  10. [ 0.8197, -1.0426],
  11. [-0.8185, -1.3367],
  12. [-0.6293, 0.6714]],
  13. [[-0.5973, -0.0944],
  14. [ 0.3720, 0.0672],
  15. [ 0.2681, 1.8025],
  16. [-0.0606, 0.4855]]])
  17. tensor([-0.9814, 0.8251, 0.8197, -1.0426, -0.8185, -1.3367, -0.6293, 0.6714,
  18. -0.5973, -0.0944, 0.3720, 0.0672, 0.2681, 1.8025, -0.0606, 0.4855])
  19. tensor([[-0.9814, 0.8251, 0.8197, -1.0426, -0.8185, -1.3367, -0.6293, 0.6714]
  20. ,
  21. [-0.5973, -0.0944, 0.3720, 0.0672, 0.2681, 1.8025, -0.0606, 0.4855]
  22. ])

 torch.flatten(x,0,1)代表在第一维和第二维之间平坦化

  1. import torch
  2. x=torch.randn(2,4,2)
  3. print(x)
  4. w=torch.flatten(x,0,1) #第一维长度2,第二维长度为4,平坦化后长度为2*4
  5. print(w.shape)
  6. print(w)
  7. 输出为:
  8. tensor([[[-0.5523, -0.1132],
  9. [-2.2659, -0.0316],
  10. [ 0.1372, -0.8486],
  11. [-0.3593, -0.2622]],
  12. [[-0.9130, 1.0038],
  13. [-0.3996, 0.4934],
  14. [ 1.7269, 0.8215],
  15. [ 0.1207, -0.9590]]])
  16. torch.Size([8, 2])
  17. tensor([[-0.5523, -0.1132],
  18. [-2.2659, -0.0316],
  19. [ 0.1372, -0.8486],
  20. [-0.3593, -0.2622],
  21. [-0.9130, 1.0038],
  22. [-0.3996, 0.4934],
  23. [ 1.7269, 0.8215],
  24. [ 0.1207, -0.9590]])

 对于torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。

  1. import torch
  2. #随机32个通道为15*5的图
  3. x=torch.randn(32,1,5,5)
  4. model=torch.nn.Sequential(
  5. #输入通道为1,输出通道为63*3的卷积核,步长为1,padding=1
  6. torch.nn.Conv2d(1,6,3,1,1),
  7. torch.nn.Flatten()
  8. )
  9. output=model(x)
  10. print(output.shape) # 6*7-3+1*7-3+1
  11. 输出为:
  12. torch.Size([32, 150])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/1007146
推荐阅读
相关标签
  

闽ICP备14008679号