当前位置:   article > 正文

PyTorch中view()函数用法说明

PyTorch中view()函数用法说明

首先,view( ) 是对 PyTorch 中的 Tensor 操作的,若非 Tensor 类型,可使用 data = torch.tensor(data)来进行转换。

(1) 作用:该函数返回一个有相同数据但不同维度大小的 Tensor。也就是说该函数的功能是改变矩阵维度,相当于 Numpy 中的 resize() 或者 Tensorflow 中的 reshape() 。

(2) 参数:view( *args )

  1. import torch
  2. x = torch.randn(6, 6)
  3. print(x.size())
  4. y = x.view(36)
  5. print(y.size())
  6. z = x.view(-1, 9) # -1表示该维度取决于其它维度大小,即(6*6)/ 9
  7. print(z.size())
  8. m = x.view(3, 3, 4) # 也可以变为更多维度
  9. print(m.size())
  10. 输出:
  11. torch.Size([6, 6])
  12. torch.Size([36])
  13. torch.Size([4, 9])
  14. torch.Size([3, 3, 4])

特殊用法view(-1)

若需要转换维度为一维,有一种简单的方式,即将参数设置为 -1

  1. import torch
  2. a = torch.Tensor([[1, 2, 3], [4, 5, 6],[7,8,9]]) # 定义一个 2*3 的 Tensor
  3. a = a.view(-1)
  4. print(a)
  5. 输出:
  6. tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/371790?site
推荐阅读
相关标签
  

闽ICP备14008679号