当前位置:   article > 正文

Vision Transformer(ViT)模型及其在图像分类任务中的应用_基于vit模型实现图像分类流程

基于vit模型实现图像分类流程

一、实验目的

1.简要介绍 Vision Transformer(ViT)模型及其在图像分类任务中的应用。

         基本概念

  • Vision Transformer (ViT) 是一种基于Transformer架构的模型,最初在自然语言处理领域获得成功,后被引入到计算机视觉领域。
  • 核心思想:ViT将图像分割成一系列小块(称为patches),这些小块被处理成序列数据,类似于在NLP中处理单词序列。
  • 自注意力机制:ViT通过自注意力机制学习图像中不同部分之间的关系,这与传统的CNN使用卷积核在空间上聚焦局部区域形成鲜明对比。

与CNN的演化

  • 从CNN到ViT:传统的卷积神经网络(CNN)依赖于局部卷积操作和池化来处理图像,这导致它们在处理图像的全局信息时存在局限性。
  • 全局信息处理:ViT通过自注意力机制克服了这一限制,能够在整个图像序列上建立长距离依赖关系。
  • 高效性:与CNN相比,ViT在处理高分辨率图像时更为高效,因为它避免了逐层卷积操作中的重复计算

2. 阐述实验目的和重要性。

         Vision Transformer在图像分类任务中的应用标志着从局部聚焦的CNN到全局信息处理的转变,这在图像理解和分析中打开了新的可能性。由于其独特的结构和处理方式,ViT在处理复杂和多样化的图像数据方面显示出显著的优势。

二、实验设计:

1. 实验的整体框架和步骤

  • 步骤一:数据准备
    • 选择猫狗图像数据集:使用专门的猫狗图像数据集,如Kaggle上的猫狗分类挑战数据集。
    • 数据预处理:包括调整图像大小、归一化等。
  • 步骤二:模型设计与实现
    • 设计Vision Transformer模型:根据ViT的架构设置模型参数,调整以适应猫狗图像特性。
    • 编写实验代码:实现模型训练、验证和测试的过程。
  • 步骤三:模型训练
    • 在训练集上训练模型:调整超参数以优化猫狗图像分类的性能。
    • 在验证集上进行性能评估。
  • 步骤四:模型测试与评估
    • 在测试集上评估模型性能。
    • 分析结果:分析模型在猫狗分类任务中的表现。

2. Vision Transformer模型的结构和原理

  • 针对猫狗分类调整模型结构
    • 输入处理:根据猫狗图像的特点调整patch大小。
    • 位置嵌入和Transformer编码器:同标准ViT模型。
    • 分类头:设计为二分类输出。
  • 模型原理
    • 自注意力机制:学习猫和狗图像中关键特征的相关性。
    • 层序连接和分类头:适用于猫狗分类任务。

3. 图像数据的预处理与划分

  • 预处理
    • 调整大小:统一图像尺寸,如将所有图像调整为256x256像素。
    • 归一化:将像素值归一化到0-1范围。
    • 数据增强:应用如随机裁剪、旋转等技术增加数据多样性。
  • 数据划分
    • 训练集:选取数据集的大部分(例如70%)用于模型训练。
    • 验证集:约15%的数据用于模型的性能调优。
    • 测试集:剩余的15%用于最终的性能评估。

三、实验过程:

1. 软件和硬件环境

  • 软件环境
    • 编程语言:Python。
    • 深度学习框架:PyTorc

2. 模型训练过程

  • 超参数的选择
    • 学习率:开始时可以设置较大的学习率(例如0.001),并在训练过程中逐步降低。
    • 批处理大小:根据GPU内存容量调整(例如32或64)。
    • Epoch数量:设置足够多的epoch以确保充分训练(例如100个epoch)。
  • 优化器的设置
    • 优化器选择:通常使用Adam优化器,因其在深度学习模型中表现良好。
    • 正则化技术:如权重衰减(L2正则化)以防止过拟合。

3. 训练和调优策略

  • 损失函数的选择
    • 对于猫狗二分类任务,使用二元交叉熵损失(Binary Cross-Entropy Loss)。
  • 训练策略
    • 早期停止:监控验证集的损失,如果在连续几个epoch中没有明显改善,则提前终止训练。
    • 学习率调整:使用学习率衰减策略,如在验证集损失停止改善时降低学习率。
    • 数据增强:在训练过程中应用数据增强技术,增加模型的泛化能力。
  • 调优过程
    • 超参数调优:通过实验不同的超参数组合来找到最佳配置。
    • 性能监控:定期检查模型在训练集和验证集上的性能,确保模型正在学习并且没有过拟合。
    • 结果记录:记录不同超参数设置下的训练结果,以便进行比较和分析。

四、实验结果:

1. 验证集上的模型表现

Loss-Acc图:

ROC曲线和AUC值:

  • 性能指标
    • 准确率:表示模型正确分类图像的比例。
    • 召回率:特别针对每个类别(猫或狗),召回率显示了模型正确识别该类别图像的能力。
    • 精确度:精确度衡量的是在预测为特定类别的图像中,实际上属于该类别的比例。
    • F1分数:F1分数是精确度和召回率的调和平均,用于评估模型的整体性能。
  • 性能图表
    • 使用混淆矩阵来可视化模型在不同类别上的性能。
    • 绘制ROC曲线和AUC值,以评估模型在不同阈值下的分类能力。

2. 实验结果分析

  • 优点分析
    • 全局特征学习:Vision Transformer因其自注意力机制能够捕捉图像的全局特征,这可能导致在特定类型的图像上表现出色。
    • 泛化能力:如果模型在多种类型的猫狗图像(如不同品种、背景等)上表现良好,这表明它具有良好的泛化能力。
    • 高准确率:一个高准确率表明模型在大多数情况下能够正确分类图像。
  • 不足之处
    • 对特定特征的敏感性:如果模型在某些特定类型的图像(如特定背景或照明条件下的图像)上表现不佳,这可能暗示模型对于某些特征过于敏感。
    • 计算资源需求:ViT模型可能需要较多的计算资源,这在实际应用中可能是一个限制因素。
    • 训练时间:与某些传统CNN模型相比,ViT可能需要更长的训练时间,特别是在缺乏优化时。

五、结论:

实验的局限性和可改进之处

  • 数据集局限性:指出如果实验只用了特定类型的猫狗图像数据集,可能限制了模型泛化到更广泛场景的能力。
  • 计算资源需求:讨论实验中ViT模型相对较高的计算资源需求,以及这可能对实际应用造成的限制。
  • 训练时间和成本:指出模型训练所需的时间和成本,特别是在资源受限的环境中,这可能是一个重要考虑因素。
  • 改进建议
    • 数据多样性:增加数据集的多样性,包括不同的图像质量、光照条件、背景等,以进一步测试和提高模型的泛化能力。
    • 模型优化:探索模型架构和训练过程的优化,以减少资源消耗和提高训练效率。
    • 后续研究方向:建议未来研究可以探索将ViT与其他技术(如CNN)结合,以利用各自的优点,或开发更轻量级的ViT变体。

总体而言,这项实验展示了Vision Transformer在猫狗图像分类任务上的有效性和潜力,同时也揭示了其在数据和计算资源方面的一些局限性。未来的研究可以在这些发现的基础上进行,以实现更广泛的应用和更优的性能。

实验代码:

  1. import copy
  2. import torch
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. from torchvision import models
  6. from torch.utils.data import DataLoader
  7. from torch import optim, nn
  8. from torch.optim import lr_scheduler
  9. import os
  10. import matplotlib.pyplot as plt
  11. import warnings
  12. import numpy as np
  13. warnings.filterwarnings("ignore")
  14. plt.rcParams['font.sans-serif'] = ['SimHei']
  15. plt.rcParams['axes.unicode_minus'] = False
  16. # 设置GPU
  17. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  18. import torch
  19. from torchvision import datasets, transforms
  20. import os
  21. # 数据集路径
  22. data_dir = "E:\深度学习\Vision Transformer\cats_and_dogs_small"
  23. # 图像的大小
  24. img_height = 224
  25. img_width = 224
  26. # 数据预处理
  27. data_transforms = {
  28. 'train': transforms.Compose([
  29. transforms.RandomResizedCrop(img_height),
  30. transforms.RandomHorizontalFlip(),
  31. transforms.RandomVerticalFlip(),
  32. transforms.RandomRotation(0.2),
  33. transforms.ToTensor(),
  34. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  35. ]),
  36. 'val': transforms.Compose([
  37. transforms.Resize((img_height, img_width)),
  38. transforms.ToTensor(),
  39. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  40. ]),
  41. }
  42. # 加载数据集
  43. full_dataset = datasets.ImageFolder(data_dir)
  44. # 获取数据集的大小
  45. full_size = len(full_dataset)
  46. train_size = int(0.7 * full_size) # 假设训练集占80%
  47. val_size = full_size - train_size # 验证集的大小
  48. # 随机分割数据集
  49. torch.manual_seed(0) # 设置随机种子以确保结果可重复
  50. train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
  51. # 将数据增强应用到训练集
  52. train_dataset.dataset.transform = data_transforms['train']
  53. # 创建数据加载器
  54. batch_size = 32
  55. train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
  56. val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
  57. dataloaders = {'train': train_dataloader, 'val': val_dataloader}
  58. dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
  59. class_names = full_dataset.classes
  60. # 定义Vision Transformer模型
  61. import timm
  62. model = timm.create_model('vit_base_patch16_224',
  63. pretrained=True) # 你可以选择适合你需求的Vision Transformer版本,这里以vit_base_patch16_224为例
  64. num_ftrs = model.head.in_features
  65. # 根据分类任务修改最后一层
  66. model.head = nn.Linear(num_ftrs, len(class_names))
  67. model = model.to(device)
  68. # 打印模型摘要
  69. print(model)
  70. # 定义损失函数
  71. criterion = nn.CrossEntropyLoss()
  72. # 定义优化器
  73. optimizer = optim.Adam(model.parameters())
  74. # 定义学习率调度器
  75. exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
  76. # 开始训练模型
  77. num_epochs = 10
  78. best_model_wts = copy.deepcopy(model.state_dict())
  79. best_acc = 0.0
  80. # 初始化记录器
  81. train_loss_history = []
  82. train_acc_history = []
  83. val_loss_history = []
  84. val_acc_history = []
  85. for epoch in range(num_epochs):
  86. print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  87. print('-' * 10)
  88. # 每个epoch都有一个训练和验证阶段
  89. for phase in ['train', 'val']:
  90. if phase == 'train':
  91. model.train() # Set model to training mode
  92. else:
  93. model.eval() # Set model to evaluate mode
  94. running_loss = 0.0
  95. running_corrects = 0
  96. # 遍历数据
  97. for inputs, labels in dataloaders[phase]:
  98. inputs = inputs.to(device)
  99. labels = labels.to(device)
  100. # 零参数梯度
  101. optimizer.zero_grad()
  102. # 前向
  103. with torch.set_grad_enabled(phase == 'train'):
  104. outputs = model(inputs)
  105. _, preds = torch.max(outputs, 1)
  106. loss = criterion(outputs, labels)
  107. # 只在训练模式下进行反向和优化
  108. if phase == 'train':
  109. loss.backward()
  110. optimizer.step()
  111. # 统计
  112. running_loss += loss.item() * inputs.size(0)
  113. running_corrects += torch.sum(preds == labels.data)
  114. epoch_loss = running_loss / dataset_sizes[phase]
  115. epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()
  116. # 记录每个epoch的loss和accuracy
  117. if phase == 'train':
  118. train_loss_history.append(epoch_loss)
  119. train_acc_history.append(epoch_acc)
  120. else:
  121. val_loss_history.append(epoch_loss)
  122. val_acc_history.append(epoch_acc)
  123. print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
  124. # 深拷贝模型
  125. if phase == 'val' and epoch_acc > best_acc:
  126. best_acc = epoch_acc
  127. best_model_wts = copy.deepcopy(model.state_dict())
  128. print()
  129. print('Best val Acc: {:4f}'.format(best_acc))
  130. epoch = range(1, len(train_loss_history) + 1)
  131. fig, ax = plt.subplots(1, 2, figsize=(10, 4))
  132. ax[0].plot(epoch, train_loss_history, label='Train loss')
  133. ax[0].plot(epoch, val_loss_history, label='Validation loss')
  134. ax[0].set_xlabel('Epochs')
  135. ax[0].set_ylabel('Loss')
  136. ax[0].legend()
  137. ax[1].plot(epoch, train_acc_history, label='Train acc')
  138. ax[1].plot(epoch, val_acc_history, label='Validation acc')
  139. ax[1].set_xlabel('Epochs')
  140. ax[1].set_ylabel('Accuracy')
  141. ax[1].legend()
  142. plt.savefig("loss-acc.pdf", dpi=300,format="pdf")
  143. from sklearn.metrics import classification_report, confusion_matrix
  144. import math
  145. import pandas as pd
  146. import numpy as np
  147. import seaborn as sns
  148. from matplotlib.pyplot import imshow
  149. # 定义一个绘制混淆矩阵图的函数
  150. def plot_cm(labels, predictions):
  151. # 生成混淆矩阵
  152. conf_numpy = confusion_matrix(labels, predictions)
  153. # 将矩阵转化为 DataFrame
  154. conf_df = pd.DataFrame(conf_numpy, index=class_names, columns=class_names)
  155. plt.figure(figsize=(8, 7))
  156. sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
  157. plt.title('Confusion matrix', fontsize=15)
  158. plt.ylabel('Actual value', fontsize=14)
  159. plt.xlabel('Predictive value', fontsize=14)
  160. def evaluate_model(model, dataloader, device):
  161. model.eval() # 设置模型为评估模式
  162. true_labels = []
  163. pred_labels = []
  164. # 遍历数据
  165. for inputs, labels in dataloader:
  166. inputs = inputs.to(device)
  167. labels = labels.to(device)
  168. # 前向
  169. with torch.no_grad():
  170. outputs = model(inputs)
  171. _, preds = torch.max(outputs, 1)
  172. true_labels.extend(labels.cpu().numpy())
  173. pred_labels.extend(preds.cpu().numpy())
  174. return true_labels, pred_labels
  175. # 获取预测和真实标签
  176. true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)
  177. # 计算混淆矩阵
  178. cm_val = confusion_matrix(true_labels, pred_labels)
  179. a_val = cm_val[0, 0]
  180. b_val = cm_val[0, 1]
  181. c_val = cm_val[1, 0]
  182. d_val = cm_val[1, 1]
  183. # 计算各种性能指标
  184. acc_val = (a_val + d_val) / (a_val + b_val + c_val + d_val) # 准确率
  185. error_rate_val = 1 - acc_val # 错误率
  186. sen_val = d_val / (d_val + c_val) # 灵敏度
  187. sep_val = a_val / (a_val + b_val) # 特异度
  188. precision_val = d_val / (b_val + d_val) # 精确度
  189. F1_val = (2 * precision_val * sen_val) / (precision_val + sen_val) # F1
  190. MCC_val = (d_val * a_val - b_val * c_val) / (
  191. np.sqrt((d_val + b_val) * (d_val + c_val) * (a_val + b_val) * (a_val + c_val))) # 马修斯相关系数
  192. # 打印出性能指标
  193. print("验证集的灵敏度为:", sen_val,
  194. "验证集的特异度为:", sep_val,
  195. "验证集的准确率为:", acc_val,
  196. "验证集的错误率为:", error_rate_val,
  197. "验证集的精确度为:", precision_val,
  198. "验证集的F1为:", F1_val,
  199. "验证集的MCC为:", MCC_val)
  200. # 绘制混淆矩阵
  201. plot_cm(true_labels, pred_labels)
  202. # 获取预测和真实标签
  203. train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
  204. # 计算混淆矩阵
  205. cm_train = confusion_matrix(train_true_labels, train_pred_labels)
  206. a_train = cm_train[0, 0]
  207. b_train = cm_train[0, 1]
  208. c_train = cm_train[1, 0]
  209. d_train = cm_train[1, 1]
  210. acc_train = (a_train + d_train) / (a_train + b_train + c_train + d_train)
  211. error_rate_train = 1 - acc_train
  212. sen_train = d_train / (d_train + c_train)
  213. sep_train = a_train / (a_train + b_train)
  214. precision_train = d_train / (b_train + d_train)
  215. F1_train = (2 * precision_train * sen_train) / (precision_train + sen_train)
  216. MCC_train = (d_train * a_train - b_train * c_train) / (
  217. math.sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train)))
  218. print("训练集的灵敏度为:", sen_train,
  219. "训练集的特异度为:", sep_train,
  220. "训练集的准确率为:", acc_train,
  221. "训练集的错误率为:", error_rate_train,
  222. "训练集的精确度为:", precision_train,
  223. "训练集的F1为:", F1_train,
  224. "训练集的MCC为:", MCC_train)
  225. # 绘制混淆矩阵
  226. plot_cm(train_true_labels, train_pred_labels)
  227. from sklearn import metrics
  228. import numpy as np
  229. import matplotlib.pyplot as plt
  230. from matplotlib.pyplot import imshow
  231. from sklearn.metrics import classification_report, confusion_matrix
  232. import seaborn as sns
  233. import pandas as pd
  234. import math
  235. def plot_roc(name, labels, predictions, **kwargs):
  236. fp, tp, _ = metrics.roc_curve(labels, predictions)
  237. plt.plot(fp, tp, label=name, linewidth=2, **kwargs)
  238. plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
  239. plt.xlabel('False positives rate')
  240. plt.ylabel('True positives rate')
  241. ax = plt.gca()
  242. ax.set_aspect('equal')
  243. # 确保模型处于评估模式
  244. model.eval()
  245. train_ds = dataloaders['train']
  246. val_ds = dataloaders['val']
  247. val_pre_auc = []
  248. val_label_auc = []
  249. for images, labels in val_ds:
  250. for image, label in zip(images, labels):
  251. img_array = image.unsqueeze(0).to(device) # 在第0维增加一个维度并将图像转移到适当的设备上
  252. prediction_auc = model(img_array) # 使用模型进行预测
  253. val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:, 1])
  254. val_label_auc.append(label.item()) # 使用Tensor.item()获取Tensor的值
  255. auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)
  256. train_pre_auc = []
  257. train_label_auc = []
  258. for images, labels in train_ds:
  259. for image, label in zip(images, labels):
  260. img_array_train = image.unsqueeze(0).to(device)
  261. prediction_auc = model(img_array_train)
  262. train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:, 1]) # 输出概率而不是标签!
  263. train_label_auc.append(label.item())
  264. auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)
  265. plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc, val_pre_auc, color="red", linestyle='--')
  266. plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
  267. plt.legend(loc='lower right')
  268. # plt.savefig("roc.pdf", dpi=300,format="pdf")
  269. print("训练集的AUC值为:", auc_score_train, "验证集的AUC值为:", auc_score_val)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/爱喝兽奶帝天荒/article/detail/1019967
推荐阅读
相关标签
  

闽ICP备14008679号