当前位置:   article > 正文

torch与numpy转换内存上的小坑_torch.from_numpy浮点数精度变小了

torch.from_numpy浮点数精度变小了

在学习机器学习中,tensor的各种转换是新手容易遇到的坑 ,我这里记录一下我遇到的一些坑

  1. 将numpy数据类型转换成Tensor

  1. a = torch.ones(5)
  2. b = a.numpy()
  3. a.add_(1) # 就地版本的add()
  4. print(a)
  5. print(b)
  1. tensor([2., 2., 2., 2., 2.])
  2. [2. 2. 2. 2. 2.]

torch中的add_()是就地版本的add(),这样b的值会随a变化,而若使用add() 则b的值是全1

  1. numpy数组转化成Torch的Tensor

  1. import numpy as np
  2. a = np.ones(5)
  3. b = torch.from_numpy(a)
  4. np.add(a,1,out=a)
  5. print(a)
  6. print(b)
  1. [2. 2. 2. 2. 2.]
  2. tensor([2., 2., 2., 2., 2.], dtype=torch.float64)

torch.from_numpy()会创建一个tensor从numpy转化来的,返回的tensor和之前的narray共享内存,下面是官方的解释:

Creates a Tensor from a numpy.ndarray.

The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.

  1. 复制tensor数据类型的数据时候

  1. x = torch.arange(12)
  2. y = torch.tensor(x)

这样复制时会报UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). y = torch.tensor(x)的警告

  1. x = torch.arange(12)
  2. y = torch.as_tensor(x)

使用as_tensor做复制可以不报错,是官方推荐的写法

  1. torch的reshape()是返回的一个view

  1. a = torch.arange(12)
  2. b = a.reshape((3,4))
  3. b[:] = 2
  4. a
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/272366
推荐阅读
相关标签
  

闽ICP备14008679号