赞
踩
在学习机器学习中,tensor的各种转换是新手容易遇到的坑 ,我这里记录一下我遇到的一些坑
将numpy数据类型转换成Tensor
- a = torch.ones(5)
- b = a.numpy()
- a.add_(1) # 就地版本的add()
- print(a)
- print(b)
- tensor([2., 2., 2., 2., 2.])
- [2. 2. 2. 2. 2.]
torch中的add_()是就地版本的add(),这样b的值会随a变化,而若使用add() 则b的值是全1
将numpy数组转化成Torch的Tensor
- import numpy as np
- a = np.ones(5)
- b = torch.from_numpy(a)
- np.add(a,1,out=a)
- print(a)
- print(b)
- [2. 2. 2. 2. 2.]
- tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
torch.from_numpy()会创建一个tensor从numpy转化来的,返回的tensor和之前的narray共享内存,下面是官方的解释:
Creates a Tensor from a numpy.ndarray.
The returned tensor and ndarray share the same memory. Modifications to the tensor will be reflected in the ndarray and vice versa. The returned tensor is not resizable.
复制tensor数据类型的数据时候
- x = torch.arange(12)
- y = torch.tensor(x)
这样复制时会报UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). y = torch.tensor(x)的警告
- x = torch.arange(12)
- y = torch.as_tensor(x)
使用as_tensor做复制可以不报错,是官方推荐的写法
torch的reshape()是返回的一个view
- a = torch.arange(12)
- b = a.reshape((3,4))
- b[:] = 2
- a
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。