
Step 50 in Deep Learning Image Recognition: Modeling with Data-efficient Image Transformers (PyTorch)

Data-efficient Image Transformer

Demonstrated on a 64-bit Windows 10 system.

I. Foreword

(1) Data-efficient Image Transformers

Data-efficient Image Transformers (DeiT) is a model for image classification proposed by Facebook AI at the end of 2020. It builds on the Vision Transformer and, through improved training strategies, reaches strong performance with far less data.

Transformer models usually need large amounts of data to perform well, which is infeasible in some settings, for example when only a small amount of labeled data is available. DeiT addresses this by applying knowledge distillation during training: a technique in which a smaller student model learns to imitate the behavior of a larger teacher model.

One of DeiT's key techniques is having the student model predict the teacher's class distribution rather than only the hard labels (the class labels in the original dataset). The benefit is that the student can learn more information from the teacher's soft labels (class probability distributions). DeiT also introduces a training variant called "hard-label distillation," which pushes performance further. With these techniques, DeiT matches or surpasses more elaborate convolutional networks such as ResNet and EfficientNet even on large-scale datasets like ImageNet, while using less compute.
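The difference between soft and hard distillation targets can be sketched in a few lines. This is an illustrative toy example, not DeiT's actual loss code; the teacher logits and the temperature value are made-up numbers for a hypothetical 3-class problem.

```python
import numpy as np

def softmax(z, T=1.0):
    """Temperature-scaled softmax; a higher T yields a softer distribution."""
    z = np.asarray(z, dtype=float) / T
    e = np.exp(z - z.max())
    return e / e.sum()

# Hypothetical teacher logits for a 3-class problem (made-up numbers).
teacher_logits = np.array([4.0, 1.0, 0.5])

# Soft label: the teacher's full class distribution (soft distillation target).
soft_label = softmax(teacher_logits, T=3.0)

# Hard label: only the teacher's argmax class (DeiT's hard-label distillation target).
hard_label = np.zeros_like(teacher_logits)
hard_label[np.argmax(teacher_logits)] = 1.0

print(soft_label)  # a full probability distribution over all classes
print(hard_label)  # a one-hot vector keeping only the top class
```

The soft target preserves the relative confidence the teacher assigns to every class, which is the extra information the student learns from; the hard target keeps only the teacher's decision.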

(2) Pretrained versions of Data-efficient Image Transformers

As before, we use the PyTorch Image Models (timm) library. It provides many pretrained models; there are too many to list here, so here is the source file instead:

https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/deit.py

Lines 155 to 249 of that file define the available DeiT variants.

Take "deit_tiny_patch16_224" as an example:

"deit": the model family, short for "Data-efficient Image Transformers".

"tiny": the model size. Here, "tiny" denotes a smaller, computationally cheaper version; other scales such as "small" and "base" are also available.

"patch16": at the input stage, the original image is split into 16x16-pixel squares (called patches) for processing.

"224": the model accepts input images of 224x224 pixels, a common image size in computer vision.
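The naming convention above can be unpacked mechanically. The short sketch below parses the example name and derives how many patch tokens the model sees; the parsing logic is my own illustration, not a timm API.

```python
# Parse a DeiT model name into its components and derive the patch grid.
name = "deit_tiny_patch16_224"
family, size, patch, img = name.split("_")  # ['deit', 'tiny', 'patch16', '224']

patch_size = int(patch[len("patch"):])  # 16-pixel patches
img_size = int(img)                     # 224x224 input images

# A 224x224 image cut into 16x16 patches gives a 14x14 grid of tokens.
grid = img_size // patch_size
num_patches = grid * grid

print(family, size, patch_size, img_size, num_patches)  # deit tiny 16 224 196
```

So each image becomes a sequence of 196 patch tokens (plus the class and distillation tokens) before entering the Transformer.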

II. Transfer learning with Data-efficient Image Transformers: hands-on code

We continue with the chest X-ray dataset: distinguishing tuberculosis patients from healthy people. There are 700 images of tuberculosis patients and 900 of healthy people, stored in separate folders.
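Since the code below loads the data with torchvision's ImageFolder, the folders must follow its one-subfolder-per-class layout. The sketch below builds a throwaway copy of that layout; the class folder names "health" and "tb" are placeholders for whatever your dataset actually uses.

```python
import os
import tempfile

# Sketch of the directory layout torchvision.datasets.ImageFolder expects:
# a data root ("MTB" here) with one subfolder per class, images inside each.
root = os.path.join(tempfile.mkdtemp(), "MTB")
for cls in ["health", "tb"]:  # placeholder class names
    os.makedirs(os.path.join(root, cls))

# ImageFolder derives class names from the sorted subfolder names.
class_names = sorted(os.listdir(root))
print(class_names)  # ['health', 'tb']
```

The order matters: class index 0 goes to the alphabetically first folder, which determines which confusion-matrix cell is "positive" later on.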

(a) Import packages

import copy
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from torch import optim, nn
from torch.optim import lr_scheduler
import os
import matplotlib.pyplot as plt
import warnings
import numpy as np
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# Use the GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(b) Load the dataset

import copy
import torch
from torchvision import datasets, transforms
import os

# Dataset path
data_dir = "./MTB"
# Image size: the deit_*_224 models expect 224x224 inputs
img_height = 224
img_width = 224

# Data preprocessing
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(img_height),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load the dataset
full_dataset = datasets.ImageFolder(data_dir)
# Dataset size
full_size = len(full_dataset)
train_size = int(0.7 * full_size)  # training set takes 70%
val_size = full_size - train_size  # validation set takes the rest

# Randomly split the dataset
torch.manual_seed(0)  # fix the random seed for reproducibility
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# Apply the transforms. The two subsets share the same underlying ImageFolder,
# so give the validation subset its own shallow copy before assigning transforms.
val_dataset.dataset = copy.copy(full_dataset)
train_dataset.dataset.transform = data_transforms['train']
val_dataset.dataset.transform = data_transforms['val']

# Create the data loaders
batch_size = 32
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
dataloaders = {'train': train_dataloader, 'val': val_dataloader}
dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
class_names = full_dataset.classes

(c) Load the Data-efficient Image Transformer

# Import the required libraries
import torch.nn as nn
import timm

# Define the Data-efficient Image Transformers model
# Pick whichever DeiT variant suits your needs; deit_base_patch16_224 is used here as an example
model = timm.create_model('deit_base_patch16_224', pretrained=True)
num_ftrs = model.head.in_features

# Replace the final layer to match the number of classes in our task
model.head = nn.Linear(num_ftrs, len(class_names))

# Move the model to the chosen device
model = model.to(device)

# Print a model summary
print(model)

(d) Set up and train the model

# Define the loss function
criterion = nn.CrossEntropyLoss()
# Define the optimizer
optimizer = optim.Adam(model.parameters())
# Define the learning-rate scheduler
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Start training
num_epochs = 10
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

# Histories for plotting later
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training phase and a validation phase
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # set model to training mode
        else:
            model.eval()   # set model to evaluation mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over the data
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass; track gradients only during training
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Backward pass and optimization only during training
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        if phase == 'train':
            exp_lr_scheduler.step()  # advance the learning-rate schedule once per epoch

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()

        # Record the loss and accuracy of each epoch
        if phase == 'train':
            train_loss_history.append(epoch_loss)
            train_acc_history.append(epoch_acc)
        else:
            val_loss_history.append(epoch_loss)
            val_acc_history.append(epoch_acc)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # Deep-copy the best model weights
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print()

print('Best val Acc: {:4f}'.format(best_acc))
# Restore the best weights for evaluation
model.load_state_dict(best_model_wts)

(e) Visualize accuracy and loss

epochs = range(1, len(train_loss_history) + 1)
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].plot(epochs, train_loss_history, label='Train loss')
ax[0].plot(epochs, val_loss_history, label='Validation loss')
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Loss')
ax[0].legend()
ax[1].plot(epochs, train_acc_history, label='Train acc')
ax[1].plot(epochs, val_acc_history, label='Validation acc')
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Accuracy')
ax[1].legend()
#plt.savefig("loss-acc.pdf", dpi=300, format="pdf")

Observe how training progresses:

Blue is the training set, orange is the validation set.

(f) Confusion matrix visualization and model metrics

from sklearn.metrics import classification_report, confusion_matrix
import math
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.pyplot import imshow

# Function that plots a confusion matrix
def plot_cm(labels, predictions):
    # Build the confusion matrix
    conf_numpy = confusion_matrix(labels, predictions)
    # Convert it to a DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names, columns=class_names)
    plt.figure(figsize=(8, 7))
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    plt.title('Confusion matrix', fontsize=15)
    plt.ylabel('Actual value', fontsize=14)
    plt.xlabel('Predictive value', fontsize=14)

def evaluate_model(model, dataloader, device):
    model.eval()  # set the model to evaluation mode
    true_labels = []
    pred_labels = []
    # Iterate over the data
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Forward pass
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
        true_labels.extend(labels.cpu().numpy())
        pred_labels.extend(preds.cpu().numpy())
    return true_labels, pred_labels

# Get predictions and true labels
true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)

# Confusion matrix for the validation set
cm_val = confusion_matrix(true_labels, pred_labels)
a_val = cm_val[0, 0]  # true negatives
b_val = cm_val[0, 1]  # false positives
c_val = cm_val[1, 0]  # false negatives
d_val = cm_val[1, 1]  # true positives

# Compute the performance metrics
acc_val = (a_val + d_val) / (a_val + b_val + c_val + d_val)  # accuracy
error_rate_val = 1 - acc_val                                 # error rate
sen_val = d_val / (d_val + c_val)                            # sensitivity
sep_val = a_val / (a_val + b_val)                            # specificity
precision_val = d_val / (b_val + d_val)                      # precision
F1_val = (2 * precision_val * sen_val) / (precision_val + sen_val)  # F1 score
MCC_val = (d_val * a_val - b_val * c_val) / (np.sqrt((d_val + b_val) * (d_val + c_val) * (a_val + b_val) * (a_val + c_val)))  # Matthews correlation coefficient

# Print the metrics
print("Validation sensitivity:", sen_val,
      "Validation specificity:", sep_val,
      "Validation accuracy:", acc_val,
      "Validation error rate:", error_rate_val,
      "Validation precision:", precision_val,
      "Validation F1:", F1_val,
      "Validation MCC:", MCC_val)

# Plot the confusion matrix
plot_cm(true_labels, pred_labels)

# Repeat on the training set
train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
cm_train = confusion_matrix(train_true_labels, train_pred_labels)
a_train = cm_train[0, 0]
b_train = cm_train[0, 1]
c_train = cm_train[1, 0]
d_train = cm_train[1, 1]
acc_train = (a_train + d_train) / (a_train + b_train + c_train + d_train)
error_rate_train = 1 - acc_train
sen_train = d_train / (d_train + c_train)
sep_train = a_train / (a_train + b_train)
precision_train = d_train / (b_train + d_train)
F1_train = (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train = (d_train * a_train - b_train * c_train) / (math.sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train)))
print("Training sensitivity:", sen_train,
      "Training specificity:", sep_train,
      "Training accuracy:", acc_train,
      "Training error rate:", error_rate_train,
      "Training precision:", precision_train,
      "Training F1:", F1_train,
      "Training MCC:", MCC_train)
# Plot the confusion matrix
plot_cm(train_true_labels, train_pred_labels)

The results look good:

(g) ROC curve and AUC

from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import pandas as pd
import math

def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = metrics.roc_curve(labels, predictions)
    plt.plot(fp, tp, label=name, linewidth=2, **kwargs)
    plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
    plt.xlabel('False positives rate')
    plt.ylabel('True positives rate')
    ax = plt.gca()
    ax.set_aspect('equal')

# Make sure the model is in evaluation mode
model.eval()

train_ds = dataloaders['train']
val_ds = dataloaders['val']

val_pre_auc = []
val_label_auc = []
with torch.no_grad():
    for images, labels in val_ds:
        for image, label in zip(images, labels):
            img_array = image.unsqueeze(0).to(device)  # add a batch dimension and move to the device
            prediction_auc = model(img_array)          # predict with the model
            val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:, 1])
            val_label_auc.append(label.item())         # Tensor.item() extracts the scalar value
auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)

train_pre_auc = []
train_label_auc = []
with torch.no_grad():
    for images, labels in train_ds:
        for image, label in zip(images, labels):
            img_array_train = image.unsqueeze(0).to(device)
            prediction_auc = model(img_array_train)
            train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:, 1])  # use the class score, not the predicted label!
            train_label_auc.append(label.item())
auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)

plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc, val_pre_auc, color="red", linestyle='--')
plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
plt.legend(loc='lower right')
#plt.savefig("roc.pdf", dpi=300, format="pdf")
print("Training AUC:", auc_score_train, "Validation AUC:", auc_score_val)

The ROC curves are as follows:

Excellent ROC curves!

III. Closing remarks

The computational cost is still high, but on this dataset the performance is somewhat better than the ViT model, which suggests the improved training strategy does pay off.

IV. Data

Link: https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

Extraction code: x3jf
