赞
踩
这是一个适合PyTorch入门者看的博客。PyTorch的文档质量比较高,入门较为容易,这篇博客选取官方链接里面的例子,介绍如何用PyTorch训练一个ResNet模型用于图像分类,代码逻辑非常清晰,基本上和许多深度学习框架的代码思路类似,非常适合初学者想上手PyTorch训练模型(不必每次都跑mnist的demo了)。接下来从个人使用角度加以解释。解释的思路是从数据导入开始到模型训练结束,基本上就是搭积木的方式来写代码。
首先是数据导入部分,这里采用官方写好的torchvision.datasets.ImageFolder接口实现数据导入。这个接口需要你提供图像所在的文件夹,就是下面的data_dir=‘/data’这句,然后对于一个分类问题,这里data_dir目录下一般包括两个文件夹:train和val,每个文件件下面包含N个子文件夹,N是你的分类类别数,且每个子文件夹里存放的就是这个类别的图像。这样torchvision.datasets.ImageFolder就会返回一个列表(比如下面代码中的image_datasets[‘train’]或者image_datasets[‘val]),列表中的每个值都是一个tuple,每个tuple包含图像和标签信息。
data_dir = '/data'
image_datasets = {x: datasets.ImageFolder(
os.path.join(data_dir, x),
data_transforms[x]),
for x in ['train', 'val']}
另外这里的data_transforms是一个字典,如下。主要是进行一些图像预处理,比如resize、crop等。实现的时候采用的是torchvision.transforms模块,比如torchvision.transforms.Compose是用来管理所有transforms操作的,torchvision.transforms.RandomSizedCrop是做crop的。需要注意的是对于torchvision.transforms.RandomSizedCrop和transforms.RandomHorizontalFlip()等,输入对象都是PIL Image,也就是用python的PIL库读进来的图像内容,而transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])的作用对象需要是一个Tensor,因此在transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])之前有一个 transforms.ToTensor()就是用来生成Tensor的。另外transforms.Scale(256)其实就是resize操作,目前已经被transforms.Resize类取代了。
data_transforms = {
'train
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。