当前位置:   article > 正文

PyTorch基础知识(超基础)_pytorch框架

pytorch框架

一、PyTorch框架介绍

PyTorch是在2017年1月由Facebook推出的。它是经典机器学习库Torch框架的一个端口,主要编程语言为python。PyTorch“曾经”的优点是动态图,现在的优点是开源代码和开源社区。

PyTorch是一个年轻的框架。2017年1月28日,PyTorch 0.1版本正式发布,这是Facebook公司在机器学习和科学计算工具Torch的基础上,针对Python语言发布的全新的深度学习工具包。PyTorch类似NumPy,并且支持GPU,有着更高级而又易用的功能,可以用来快捷地构建和训练深度神经网络。一经发布,PyTorch便受到深度学习和开发者们广泛关注和讨论。经过一年多的发展,目前PyTorch已经成为机器学习和深度学习者重要的研究和开发工具之一。

二、tensor(张量)的属性

张量是PyTorch里的基本运算单位,与numpy的ndarray相同都表示一个多维的矩阵。与ndarray最大的区别在于Tensor能使用GPU加速,而ndarray只能用在CPU上。

1、导入torch库

  1. import torch
  2. a = torch.tensor([1,2,3],dtype=int)
  3. b = torch.tensor([4,5,6],dtype=float)
  4. print(a)
  5. print(a.dtype)
  6. print(b)
  7. //结果
  8. //tensor([1, 2, 3])
  9. //torch.int64 64位的int类型
  10. //tensor([4., 5., 6.], dtype=torch.float64)

2、数据的维度

tensor = torch.tensor([[1,2,3],[4,5,6]])
print(tensor.ndim)
//结果
//2

3、数据的形状

tensor = torch.tensor([[1,2,3],[4,5,6]])
print(tensor.shape)
print(tensor.size())
//结果
//torch.Size([2, 3])
//torch.Size([2, 3])

4、基础运算

  1. sample = torch.rand(3, 2)
  2. print(sample)
  3. # 求总和
  4. print(torch.sum(sample))
  5. # 求最小值
  6. print(torch.min(sample))
  7. # 求最小值所在的位置(索引)
  8. print(torch.argmin(sample))
  9. # 求最大值所在的位置(索引)
  10. print(torch.argmax(sample))
  11. # 求平均值
  12. print(torch.mean(sample))
  13. # 求中位数
  14. print(torch.median(sample))
  15. # 求开方
  16. print(torch.sqrt(sample))
  17. # 求平方
  18. print(sample ** 2)
  19. //结果
  20. tensor([[0.9740, 0.4381],
  21. [0.9116, 0.7061],
  22. [0.9465, 0.0322]])
  23. tensor(4.0085)
  24. tensor(0.0322)
  25. tensor(5)
  26. tensor(0)
  27. tensor(0.6681)
  28. tensor(0.7061)
  29. tensor([[0.9869, 0.6619],
  30. [0.9548, 0.8403],
  31. [0.9729, 0.1793]])
  32. tensor([[0.9487, 0.1919],
  33. [0.8310, 0.4986],
  34. [0.8960, 0.0010]])

三、数据生成

1、生成数据集

  1. import torch
  2. a = torch.tensor([1,2,3],dtype=int)
  3. b = torch.tensor([4,5,6],dtype=float)
  4. print(a)
  5. print(a.dtype)
  6. print(b)
  7. //结果
  8. //tensor([1, 2, 3])
  9. //torch.int64 64位的int类型
  10. //tensor([4., 5., 6.], dtype=torch.float64)

2、读取数据集

  1. def load_array(data_arrays, batch_size, is_train=True): #@save
  2. # 构造一个PyTorch数据迭代器
  3. # torch.utils.data.TensorDataset(Dataset):包装数据和目标张量的数据集,直接传入张量参数即可(*tensor)
  4. # 该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。
  5. dataset = data.TensorDataset(*data_arrays)
  6. # 通过DataLoader类从dataset中选取数据,每次读取大小为batch_size
  7. return data.DataLoader(dataset, batch_size, shuffle=is_train)
  8. batch_size = 10
  9. # 传入张量数组,返回一个可迭代的张量组
  10. data_iter = load_array((features, labels), batch_size)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/385651
推荐阅读
相关标签
  

闽ICP备14008679号