当前位置:   article > 正文

(系列更新完毕)深度学习零基础使用 PyTorch 框架跑 MNIST 数据集的第二天:加载 MNIST 数据集_pythorch跑mnist

pythorch跑mnist

1. Introduction

今天是尝试用 PyTorch 框架来跑 MNIST 手写数字数据集的第二天,主要学习加载 MNIST 数据集。本 blog 主要记录一个学习的路径以及学习资料的汇总。

注意:这是用 Python 2.7 版本写的代码

第一天(LeNet 网络的搭建):https://blog.csdn.net/qq_36627158/article/details/108098147

第二天(加载 MNIST 数据集):https://blog.csdn.net/qq_36627158/article/details/108119048

第三天(训练模型):https://blog.csdn.net/qq_36627158/article/details/108163693

第四天(单例测试):https://blog.csdn.net/qq_36627158/article/details/108183655

 

 

 

2. Code(lenet.py)

感谢 凯神 提供的代码与耐心指导!

  1. import torchvision.transforms as transforms
  2. from torch.utils.data import Dataset, DataLoader
  3. import glob
  4. import os.path as osp
  5. from PIL import Image
  6. import matplotlib.pyplot as plt
  7. TRAIN_BATCH_SIZE = 128
  8. TEST_BATCH_SIZE = 1000
  9. class MNIST(Dataset): # define a class named MNIST
  10. # read all pictures' filename
  11. def __init__(self, root, transform=None):
  12. self.filenames = []
  13. self.transform = transform
  14. # read filenames
  15. for i in range(10):
  16. # 'root/0/all_png'
  17. filenames = glob.glob(osp.join(root, str(i), '*.png'))
  18. for fn in filenames:
  19. # (filename, label)
  20. self.filenames.append((fn, i))
  21. self.len = len(self.filenames)
  22. # Get a sample from the dataset
  23. # Return an image and it's label
  24. def __getitem__(self, index):
  25. # open the image
  26. image_fn, label = self.filenames[index]
  27. image = Image.open(image_fn)
  28. # May use transform function to transform samples
  29. if self.transform is not None:
  30. image = self.transform(image)
  31. return image, label
  32. # get the length of dataset
  33. def __len__(self):
  34. return self.len
  35. # define the transformation
  36. # PIL images -> torch tensors [0, 1]
  37. transform = transforms.Compose([
  38. transforms.ToTensor()
  39. ])
  40. # 2. load the MNIST training dataset
  41. trainset = MNIST(
  42. root='/home/ubuntu/Downloads/C6/mnist_png/training',
  43. transform=transform
  44. )
  45. # divide the dataset into batches
  46. trainset_loader = DataLoader(
  47. trainset,
  48. batch_size=TRAIN_BATCH_SIZE,
  49. shuffle=True,
  50. num_workers=0
  51. )
  52. # 3. load the MNIST testing dataset
  53. testset = MNIST(
  54. root='/home/ubuntu/Downloads/C6/mnist_png/testing',
  55. transform=transform
  56. )
  57. # divide the dataset into batches
  58. testset_loader = DataLoader(
  59. testset,
  60. batch_size=TEST_BATCH_SIZE,
  61. shuffle=False,
  62. num_workers=0
  63. )

 

 

 

3. Materials

1、Dataset 的抽象类官方文档:

https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/

 

2、DataLoader 类的官方文档:

https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

 

 

 

4、Code Details

1、__init__() 函数

注意:__init__并不相当于C#中的构造函数,执行它的时候,实例已构造出来了。__init__作用是初始化已实例化后的对象

图文均来自链接:https://www.cnblogs.com/insane-Mr-Li/p/9758776.html

 

2、Dataset 的子类都应该要重写 __len__() 和 __getitem__() 函数。前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

之前看代码,一直没有看到具体体现 __getitem__() 函数的使用地方。

后面查到了:只要继承了 Dataset 这个类后,就可以通过类的实例化对象的索引来调用到 _getitem_() 了。如: data[0]

https://www.zhihu.com/question/383099903

(图也是链接里的)

 

3、enumerate() 函数

将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据数据下标,一般用在 for 循环当中。

  1. seq = ['one', 'two', 'three']
  2. for i, element in enumerate(seq):
  3. print i, element
  4. # 0 one
  5. # 1 two
  6. # 2 three

https://www.runoob.com/python/python-func-enumerate.html

 

4、Batch Size

Batch Size的理解:https://blog.csdn.net/qq_34886403/article/details/82558399

batch size 设置技巧:https://blog.csdn.net/kl1411/article/details/82983971

顺便找到了一个小白科普贴:深度学习中GPU和显存分析

 

5、Dataloader 中的 num_worker

https://www.cnblogs.com/hesse-summer/p/11343870.html

https://blog.csdn.net/breeze210/article/details/99679048

 

6、迭代器(iterator)

迭代是Python最强大的功能之一,是访问集合元素的一种方式。

迭代器是一个可以记住遍历的位置的对象。

迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。

迭代器有两个基本的方法:iter() 和 next()

https://www.runoob.com/python3/python3-iterator-generator.html

 

7、DataLoader, DataSet, Sampler之间的关系

https://zhuanlan.zhihu.com/p/76893455

 

8、DataLoader 的索引

  • dataloader本质是一个可迭代对象,使用 iter() 访问,不能使用 next() 访问
  • 使用 iter(dataloader) 返回的是一个迭代器,然后可以使用next访问
  • 也可以使用 for inputs, labels in dataloaders 进行可迭代对象的访问
  • 一般我们实现一个datasets对象,传入到dataloader中;然后内部使用yeild返回每一次batch的数据

https://www.cnblogs.com/ranjiewen/p/10128046.html

 

9、Python glob.glob使用

https://www.cnblogs.com/luminousjj/p/9359543.html

https://www.cnblogs.com/luminousjj/p/9359543.html

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号