当前位置:   article > 正文

深度学习——图像分类模型最简单的train.py_图像分类深度学习模型

图像分类深度学习模型
  1. import os
  2. import sys
  3. import json
  4. import torch
  5. import torch.nn as nn
  6. from torchvision import transforms, datasets
  7. import torch.optim as optim
  8. from tqdm import tqdm
  9. #from classic_models.alexnet import AlexNet
  10. from classic_models.googlenet_v1 import GoogLeNet
  11. def main():
  12. # 判断可用设备
  13. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  14. print("using {} device.".format(device))
  15. # 注意改成自己的数据集路径
  16. data_path = "G:\\flower"
  17. assert os.path.exists(data_path), "{} path does not exist.".format(data_path)
  18. # 数据预处理与增强
  19. """
  20. ToTensor()能够把灰度范围从0-255变换到0-1之间的张量.
  21. transform.Normalize()则把0-1变换到(-1,1). 具体地说, 对每个通道而言, Normalize执行以下操作: image=(image-mean)/std
  22. 其中mean和std分别通过(0.5,0.5,0.5)和(0.5,0.5,0.5)进行指定。原来的0-1最小值0则变成(0-0.5)/0.5=-1; 而最大值1则变成(1-0.5)/0.5=1.
  23. 也就是一个均值为0, 方差为1的正态分布. 这样的数据输入格式可以使神经网络更快收敛。
  24. """
  25. data_transform = {
  26. "train": transforms.Compose([transforms.Resize(224), # 将图片的短边缩放到224,图片的长边和短边的比值不变,即不能保证每张图片都是224*224大小,那么下一步的裁剪就有必要了
  27. transforms.CenterCrop(224), # 由中心向两边进行裁剪,裁剪的尺寸为224*224
  28. transforms.ToTensor(), # 可以将PIL和numpy格式的数据从[0,255]范围转换到[0,1] 。另外原始数据的shape是(H x W x C),这步后shape会变为(C x H x W)
  29. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), # Normalize(mean, std, inplace=False),三通道中Normalize里面一般是Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),将上一步的数据范围由[0,1]转换为[-1,1]
  30. "val": transforms.Compose([transforms.Resize((224, 224)), # val不需要任何数据增强
  31. transforms.ToTensor(),
  32. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
  33. # 使用ImageFlolder加载数据集中的图像,并使用指定的预处理操作来处理图像, ImageFlolder会同时返回图像和对应的标签。 (image path, class_index) tuples
  34. train_dataset = datasets.ImageFolder(root=os.path.join(data_path, "train"), transform=data_transform["train"]) # root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
  35. validate_dataset = datasets.ImageFolder(root=os.path.join(data_path, "val"), transform=data_transform["val"]) # transform:对图片进行预处理的操作(函数)。在data_transform中已经定义好
  36. train_num = len(train_dataset) # 计算train_dataset里面的图片个数
  37. val_num = len(validate_dataset) # 计算validate_dataset里面的图片个数
  38. # 使用class_to_idx给类别一个index,作为训练时的标签: {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
  39. flower_list = train_dataset.class_to_idx # class_to_idx是train_dataset里面的一个函数,返回一个字典,即flower_list是一个字典
  40. # 创建一个字典,存储index和类别的对应关系,在模型推理阶段会用到。
  41. cla_dict = dict((val, key) for key, val in flower_list.items()) # items()方法将字典里对应的一对键和值以元组的形式(键, 值),存储为所生成序列里的单个元素
  42. # 将字典写成一个json文件
  43. json_str = json.dumps(cla_dict, indent=4) # json.dumps()是把python对象转换成json对象的一个过程,生成的是字符串。
  44. with open(os.path.join(data_path, 'class_indices.json') , 'w') as json_file:
  45. json_file.write(json_str)
  46. batch_size = 32 # batch_size大小,是超参,可调,如果模型跑不起来,尝试调小batch_size
  47. # 使用 DataLoader 将 ImageFloder 加载的数据集处理成批量(batch)加载模式
  48. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True )
  49. validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4, shuffle=False ) # 注意,验证集不需要shuffle
  50. print("using {} images for training, {} images for validation.".format(train_num, val_num))
  51. # 实例化模型,并送进设备
  52. net = GoogLeNet(num_classes = 5) # 使用GoogleNet来定义网络模型,分类数为5
  53. #net = AlexNet(num_classes=5 )
  54. net.to(device)
  55. # 指定损失函数用于计算损失;指定优化器用于更新模型参数;指定训练迭代的轮数,训练权重的存储地址
  56. loss_function = nn.CrossEntropyLoss() # 交叉熵函数
  57. optimizer = optim.Adam(net.parameters(), lr=0.0002)
  58. epochs = 1
  59. save_path = os.path.abspath(os.path.join(os.getcwd(), './results/weights/alexnet')) # os.getcwd()返回当前的文件目录,也就是后面的文件目录
  60. if not os.path.exists(save_path):
  61. os.makedirs(save_path) # 创建名为save_path的目录
  62. best_acc = 0.0 # 初始化验证集上最好的准确率,以便后面用该指标筛选模型最优参数。
  63. for epoch in range(epochs):
  64. ############################################################## train ######################################################
  65. net.train()
  66. acc_num = torch.zeros(1).to(device) # 初始化,用于计算训练过程中预测正确的数量
  67. sample_num = 0 # 初始化,用于记录当前迭代中,已经计算了多少个样本
  68. # tqdm是一个进度条显示器,可以在终端打印出现在的训练进度
  69. # train_loader:是需要迭代的对象,通常为列表或者生成器,其中包含训练数据。进度条会遍历该对象,并相应地更新进度。
  70. # file=sys.stdout:这个参数指定了进度条应该写入其输出的位置。在这种情况下,它被设置为sys.stdout,表示标准输出流(通常是控制台)。因此,进度条将显示在控制台中。
  71. # ncols=100:这个参数设置进度条的宽度,以字符为单位。在这里,它被设置为100个字符
  72. train_bar = tqdm(train_loader, file=sys.stdout, ncols=100)
  73. for data in train_bar :
  74. images, labels = data
  75. sample_num += images.shape[0] #[32, 3, 224, 224]
  76. optimizer.zero_grad() # 梯度初始化为零,把loss关于weight的导数变成0,避免梯度的叠加效应
  77. outputs = net(images.to(device)) # output_shape: [batch_size, num_classes]
  78. pred_class = torch.max(outputs, dim=1)[1] # torch.max 返回值是一个tuple,第一个元素是max值,第二个元素是max值的索引。
  79. acc_num += torch.eq(pred_class, labels.to(device)).sum() # 是一个比较操作,它会将预测的类别(pred_class)和标签(labels)进行逐元素的比较,返回一个布尔类型的张量,表示对应位置上两个值是否相等。
  80. # sum() 是对布尔类型的张量进行求和操作,将所有为 True 的元素相加,得到一个标量值,表示预测正确的样本数量。
  81. loss = loss_function(outputs, labels.to(device)) # 求损失
  82. loss.backward() # 自动求导
  83. optimizer.step() # 梯度下降
  84. # print statistics
  85. train_acc = acc_num.item() / sample_num
  86. # .desc是进度条tqdm中的成员变量,作用是描述信息
  87. train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
  88. # validate
  89. net.eval() #不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;
  90. acc_num = 0.0 # accumulate accurate number per epoch
  91. with torch.no_grad():
  92. for val_data in validate_loader:
  93. val_images, val_labels = val_data
  94. outputs = net(val_images.to(device))
  95. predict_y = torch.max(outputs, dim=1)[1]
  96. acc_num += torch.eq(predict_y, val_labels.to(device)).sum().item()
  97. val_accurate = acc_num / val_num
  98. print('[epoch %d] train_loss: %.3f train_acc: %.3f val_accuracy: %.3f' % (epoch + 1, loss, train_acc, val_accurate))
  99. # 判断当前验证集的准确率是否是最大的,如果是,则更新之前保存的权重
  100. if val_accurate > best_acc:
  101. best_acc = val_accurate
  102. torch.save(net.state_dict(), os.path.join(save_path, "AlexNet.pth") ) # state_dict()返回一个包含了模型所有参数(权重和偏置)的字典。这个字典中的键是参数的名称,而对应的值则是该参数的张量。
  103. # 每次迭代后清空这些指标,重新计算
  104. train_acc = 0.0
  105. val_accurate = 0.0
  106. print('Finished Training')
  107. # if __name__ == '__main__':
  108. # main()
  109. main()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/151553
推荐阅读
相关标签
  

闽ICP备14008679号