赞
踩
其他笔记在专栏 深度学习 中。
a = torch.ones(5)
b = a.numpy()
print(type(a), type(b)) #a是Tensor,b是numpy
print(a, b)
a += 1
print(a, b)
b += 1
print(a, b)
<class 'torch.Tensor'> <class 'numpy.ndarray'>
tensor([1., 1., 1., 1., 1.]) [1. 1. 1. 1. 1.]
tensor([2., 2., 2., 2., 2.]) [2. 2. 2. 2. 2.]
tensor([3., 3., 3., 3., 3.]) [3. 3. 3. 3. 3.]
c = torch.from_numpy(b)
print(b, c)
b += 1
print(b, c)
c += 1
print(b, c)
[3. 3. 3. 3. 3.] tensor([3., 3., 3., 3., 3.])
[4. 4. 4. 4. 4.] tensor([4., 4., 4., 4., 4.])
[5. 5. 5. 5. 5.] tensor([5., 5., 5., 5., 5.])
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
print(a, b)
a += 1
print(a, b)
b += 1
print(a, b)
[1. 1. 1. 1. 1.] tensor([1., 1., 1., 1., 1.], dtype=torch.float64)
[2. 2. 2. 2. 2.] tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
[3. 3. 3. 3. 3.] tensor([3., 3., 3., 3., 3.], dtype=torch.float64)
该方法总是会进行数据拷贝,返回的Tensor和原来的数据不再共享内存。
c = torch.tensor(a)
a += 1
print(a, c)
[4. 4. 4. 4. 4.] tensor([3., 3., 3., 3., 3.], dtype=torch.float64)
x = torch.tensor([3.5])
print(x, x.item(), float(a))
tensor([3.5000]) 3.5 3.5 3
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。