当前位置:   article > 正文

PyTorch 笔记(03)— Tensor 数据类型分类(默认数据类型、CPU tensor、GPU tensor、CPU 和 GPU 之间的转换、数据类型之间转换)_tensor类型

tensor类型

1. Tensor 数据类型

Tensor 有不同的数据类型,如下表所示,每种类型都有 CPUGPU 版本(HalfTensor)除外,默认的 tensor 是数据类型是 FloatTensor,只能通过 t.set_default_tensor_type 修改 tensor 为浮点类型,(如果默认类型为 GPU tensor,则所有的操作都在 GPU 上进行)。

获取 torch 默认的数据类型。

In [113]: import torch as t

In [114]: t.get_default_dtype()                                                                                                                                      
Out[114]: torch.float32
  • 1
  • 2
  • 3
  • 4

HalfTensor 是专门为 GPU 版本设计的,同样的元素个数显存只有 FloatTensor 的一半,可以缓解 GPU 显存不足问题,但由于 HalfTensor 能表示的数值大小和精度有限,有可能出现溢出等问题。

Tensor数据类型

使用 t.set_default_tensor_type 将默认数据类型修改为 IntTensor 时会报错。

In [124]: t.set_default_tensor_type(t.IntTensor)                                                                                                                     
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-124-434c3566c688> in <module>
----> 1 t.set_default_tensor_type(t.IntTensor)

/usr/local/lib/python3.6/dist-packages/torch/__init__.py in set_default_tensor_type(t)
    204     if isinstance(t, _string_classes):
    205         t = _import_dotted_name(t)
--> 206     _C._set_default_tensor_type(t)
    207 
    208 

TypeError: only floating-point types are supported as the default type
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

只能设置为 FloatTensor 类型。

In [125]: t.set_default_tensor_type(t.FloatTensor) 
  • 1

1.1 torch.FloatTensor

用于生成数据类型为浮点型Tensor,传递给 torch.FloatTensor 的参数可以是一个列表,也可以是一个维度值。

import torch

a = torch.FloatTensor(2, 3)
b = torch.FloatTensor([1, 2, 3, 4])

  • 1
  • 2
  • 3
  • 4
  • 5

输出结果:

tensor([[2.3489e-37, 4.5835e-41, 2.3489e-37],
        [4.5835e-41, 4.4842e-44, 0.0000e+00]])
        
tensor([1., 2., 3., 4.])
  • 1
  • 2
  • 3
  • 4

可以看到,打印输出的两组变量数据类型都显示为浮点型,不同的是,前面的一组是按照我们指定的维度随机生成的浮点型 Tensor 而另外一组是按我们给定的列表生成的浮点型 Tensor

1.2 torch.IntTensor

用于生成数据类型为整型的 Tensor。传递给 torch.IntTensor 的参数可以是一个列表,也可以是一个维度值。

import torch

a = torch.IntTensor(2, 3)
b = torch.IntTensor([1, 2, 3, 4])
  • 1
  • 2
  • 3
  • 4

输出结果:

tensor([[1491430264,      32561, 1491430264],
        [     32561,  808464432,  808463205]], dtype=torch.int32)
tensor([1, 2, 3, 4], dtype=torch.int32)
  • 1
  • 2
  • 3

可以看出输出的数据类型都为整形(torch.int32

2. 数据类型之间转换

各数据类型之间可以相互转换,type(new_type) 是通用的做法,同时还有 float/long/half 等快捷方法。

In [101]: a = t.ones(2,3)                                                                                                                                            

In [102]: a                                                                                                                                                          
Out[102]: 
tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [103]: a.type()                                                                                                                                                   
Out[103]: 'torch.FloatTensor'

In [104]: b = a                                                                                                                                                      

In [105]: b.int()                                                                                                                                                    
Out[105]: 
tensor([[1, 1, 1],
        [1, 1, 1]], dtype=torch.int32)

In [106]: a.type(t.IntTensor)      # 等价于 a.int()                                                                                                                                  
Out[106]: 
tensor([[1, 1, 1],
        [1, 1, 1]], dtype=torch.int32)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

3. CPU 和 GPU 之间转换

CPU tensorGPU tensor 之间的互相转换可以通过 tensor.cudatensor.cpu 方法来实现。

In [115]: a = t.ones(2,3)                                                                                                                                            

In [116]: a.type()                                                                                                                                                   
Out[116]: 'torch.FloatTensor'

In [117]: a                                                                                                                                                          
Out[117]: 
tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [118]: a.cuda()                                                                                                                                                   
Out[118]: 
tensor([[1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')

In [119]: b = a.cuda()                                                                                                                                               

In [120]: b                                                                                                                                                          
Out[120]: 
tensor([[1., 1., 1.],
        [1., 1., 1.]], device='cuda:0')

In [121]: b.type()                                                                                                                                                   
Out[121]: 'torch.cuda.FloatTensor'

In [122]: b.cpu()                                                                                                                                                    
Out[122]: 
tensor([[1., 1., 1.],
        [1., 1., 1.]])

In [123]: b.cpu().type()                                                                                                                                             
Out[123]: 'torch.FloatTensor'

In [124]: 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/272512
推荐阅读
相关标签
  

闽ICP备14008679号