当前位置:   article > 正文

第54步 深度学习图像识别:MLP-Mixer建模(Pytorch)

mlp-mixer

基于WIN10的64位系统演示

一、写在前面

(1)MLP-Mixer

MLP-Mixer(Multilayer Perceptron Mixer)是Google在2021年提出的一种新型的视觉模型结构。它的主要特点是完全使用多层感知机(MLP)来处理图像,而不是使用常见的卷积(Convolution)或者自注意力(Self-Attention)机制。

MLP-Mixer的结构主要包括两种类型的层:Token Mixing层和Channel Mixing层。在Token Mixing层中,模型会将图像分割成若干个patch(类似于像素块),然后对这些patch进行处理。在Channel Mixing层中,模型会对每个patch的通道进行处理。这两种类型的层交替堆叠,形成了最终的模型结构。

MLP-Mixer的设计目标是探索除卷积和自注意力之外的其他可能的模型结构,以期在保持性能的同时,降低模型的复杂性和计算成本。实验结果显示,MLP-Mixer在一些图像分类任务上的性能可以与ResNet和Transformer等主流模型相媲美。

然而,需要注意的是,虽然MLP-Mixer在某些方面展现出了很好的性能,但它并不意味着会替代卷积或者自注意力模型。实际上,每种模型都有其适用的场景和优势,MLP-Mixer提供了一个新的视角和工具,供我们处理视觉任务。

(2)MLP-Mixer的码源

本文使用 mlp-mixer-pytorch 库来实现MLP-Mixer。

当然,得先安装这个库:

(a)首先,打开Anaconda Prompt。在开始菜单中找到它,或者直接在搜索栏中输入"Anaconda Prompt"。在打开的Anaconda Prompt中,如果你想在一个特定的环境中安装mlp_mixer_pytorch,你需要先激活这个环境。假设你的环境名为myenv,你可以使用以下命令来激活这个环境:

conda activate myenv

(b)接下来,使用pip来安装mlp_mixer_pytorch库。在Anaconda Prompt中输入以下命令并按回车键:

pip install mlp-mixer-pytorch

二、MLP-Mixer迁移学习代码实战

我们继续胸片的数据集:肺结核病人和健康人的胸片的识别。其中,肺结核病人700张,健康人900张,分别存入单独的文件夹中。

(a)导入包

  1. import copy
  2. import torch
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. from torchvision import models
  6. from torch.utils.data import DataLoader
  7. from torch import optim, nn
  8. from torch.optim import lr_scheduler
  9. import os
  10. import matplotlib.pyplot as plt
  11. import warnings
  12. import numpy as np
  13. warnings.filterwarnings("ignore")
  14. plt.rcParams['font.sans-serif'] = ['SimHei']
  15. plt.rcParams['axes.unicode_minus'] = False
  16. # 设置GPU
  17. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

(b)导入数据集

  1. import torch
  2. from torchvision import datasets, transforms
  3. import os
  4. # 数据集路径
  5. data_dir = "./MTB"
  6. # 图像的大小
  7. img_height = 256
  8. img_width = 256
  9. # 数据预处理
  10. data_transforms = {
  11. 'train': transforms.Compose([
  12. transforms.RandomResizedCrop(img_height),
  13. transforms.RandomHorizontalFlip(),
  14. transforms.RandomVerticalFlip(),
  15. transforms.RandomRotation(0.2),
  16. transforms.ToTensor(),
  17. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  18. ]),
  19. 'val': transforms.Compose([
  20. transforms.Resize((img_height, img_width)),
  21. transforms.ToTensor(),
  22. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  23. ]),
  24. }
  25. # 加载数据集
  26. full_dataset = datasets.ImageFolder(data_dir)
  27. # 获取数据集的大小
  28. full_size = len(full_dataset)
  29. train_size = int(0.7 * full_size) # 假设训练集占80%
  30. val_size = full_size - train_size # 验证集的大小
  31. # 随机分割数据集
  32. torch.manual_seed(0) # 设置随机种子以确保结果可重复
  33. train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
  34. # 将数据增强应用到训练集
  35. train_dataset.dataset.transform = data_transforms['train']
  36. # 创建数据加载器
  37. batch_size = 32
  38. train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  39. val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  40. dataloaders = {'train': train_dataloader, 'val': val_dataloader}
  41. dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
  42. class_names = full_dataset.classes

(c)导入MLPMixer

  1. from mlp_mixer_pytorch import MLPMixer
  2. num_classes = len(class_names) # 根据数据集的类别数量来设置模型的输出类别数量
  3. # 构建MLP-Mixer模型
  4. model = MLPMixer(
  5. image_size = img_height, # 图像的高和宽
  6. channels = 3, # 图像的通道数
  7. patch_size = 16, # MLP-Mixer的patch大小
  8. dim = 512, # MLP-Mixer的维度
  9. depth = 12, # MLP-Mixer的深度
  10. num_classes = num_classes # 输出类别数量
  11. )
  12. # 将模型移动到GPU
  13. model = model.to(device)
  14. # 打印模型摘要
  15. print(model)

说明:mlp-mixer-pytorch库的主要功能就是提供了一个MLP-Mixer的类,可以通过实例化这个类来创建一个MLP-Mixer模型。在创建模型时,可以通过参数来设置图像的大小、通道数、patch的大小、模型的维度、深度以及输出类别的数量等。

需要注意的是,mlp-mixer-pytorch库提供的MLP-Mixer模型默认是随机初始化的,也就是说并没有加载预训练权重。如果你有MLP-Mixer的预训练权重,可以在创建模型后加载。

(d)编译模型

  1. # 定义损失函数
  2. criterion = nn.CrossEntropyLoss()
  3. # 定义优化器
  4. optimizer = optim.Adam(model.parameters())
  5. # 定义学习率调度器
  6. exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
  7. # 开始训练模型
  8. num_epochs = 20
  9. best_model_wts = copy.deepcopy(model.state_dict())
  10. best_acc = 0.0
  11. # 初始化记录器
  12. train_loss_history = []
  13. train_acc_history = []
  14. val_loss_history = []
  15. val_acc_history = []
  16. for epoch in range(num_epochs):
  17. print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  18. print('-' * 10)
  19. # 每个epoch都有一个训练和验证阶段
  20. for phase in ['train', 'val']:
  21. if phase == 'train':
  22. model.train() # Set model to training mode
  23. else:
  24. model.eval() # Set model to evaluate mode
  25. running_loss = 0.0
  26. running_corrects = 0
  27. # 遍历数据
  28. for inputs, labels in dataloaders[phase]:
  29. inputs = inputs.to(device)
  30. labels = labels.to(device)
  31. # 零参数梯度
  32. optimizer.zero_grad()
  33. # 前向
  34. with torch.set_grad_enabled(phase == 'train'):
  35. outputs = model(inputs)
  36. _, preds = torch.max(outputs, 1)
  37. loss = criterion(outputs, labels)
  38. # 只在训练模式下进行反向和优化
  39. if phase == 'train':
  40. loss.backward()
  41. optimizer.step()
  42. # 统计
  43. running_loss += loss.item() * inputs.size(0)
  44. running_corrects += torch.sum(preds == labels.data)
  45. epoch_loss = running_loss / dataset_sizes[phase]
  46. epoch_acc = (running_corrects.double() / dataset_sizes[phase]).item()
  47. # 记录每个epoch的loss和accuracy
  48. if phase == 'train':
  49. train_loss_history.append(epoch_loss)
  50. train_acc_history.append(epoch_acc)
  51. else:
  52. val_loss_history.append(epoch_loss)
  53. val_acc_history.append(epoch_acc)
  54. print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
  55. # 深拷贝模型
  56. if phase == 'val' and epoch_acc > best_acc:
  57. best_acc = epoch_acc
  58. best_model_wts = copy.deepcopy(model.state_dict())
  59. print()
  60. print('Best val Acc: {:4f}'.format(best_acc))
  61. # 加载最佳模型权重
  62. #model.load_state_dict(best_model_wts)
  63. #torch.save(model, 'shufflenet_best_model.pth')
  64. #print("The trained model has been saved.")

(e)Accuracy和Loss可视化

  1. epoch = range(1, len(train_loss_history)+1)
  2. fig, ax = plt.subplots(1, 2, figsize=(10,4))
  3. ax[0].plot(epoch, train_loss_history, label='Train loss')
  4. ax[0].plot(epoch, val_loss_history, label='Validation loss')
  5. ax[0].set_xlabel('Epochs')
  6. ax[0].set_ylabel('Loss')
  7. ax[0].legend()
  8. ax[1].plot(epoch, train_acc_history, label='Train acc')
  9. ax[1].plot(epoch, val_acc_history, label='Validation acc')
  10. ax[1].set_xlabel('Epochs')
  11. ax[1].set_ylabel('Accuracy')
  12. ax[1].legend()
  13. #plt.savefig("loss-acc.pdf", dpi=300,format="pdf")

观察模型训练情况:

 蓝色为训练集,橙色为验证集。

(f)混淆矩阵可视化以及模型参数

  1. from sklearn.metrics import classification_report, confusion_matrix
  2. import math
  3. import pandas as pd
  4. import numpy as np
  5. import seaborn as sns
  6. from matplotlib.pyplot import imshow
  7. # 定义一个绘制混淆矩阵图的函数
  8. def plot_cm(labels, predictions):
  9. # 生成混淆矩阵
  10. conf_numpy = confusion_matrix(labels, predictions)
  11. # 将矩阵转化为 DataFrame
  12. conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)
  13. plt.figure(figsize=(8,7))
  14. sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
  15. plt.title('Confusion matrix',fontsize=15)
  16. plt.ylabel('Actual value',fontsize=14)
  17. plt.xlabel('Predictive value',fontsize=14)
  18. def evaluate_model(model, dataloader, device):
  19. model.eval() # 设置模型为评估模式
  20. true_labels = []
  21. pred_labels = []
  22. # 遍历数据
  23. for inputs, labels in dataloader:
  24. inputs = inputs.to(device)
  25. labels = labels.to(device)
  26. # 前向
  27. with torch.no_grad():
  28. outputs = model(inputs)
  29. _, preds = torch.max(outputs, 1)
  30. true_labels.extend(labels.cpu().numpy())
  31. pred_labels.extend(preds.cpu().numpy())
  32. return true_labels, pred_labels
  33. # 获取预测和真实标签
  34. true_labels, pred_labels = evaluate_model(model, dataloaders['val'], device)
  35. # 计算混淆矩阵
  36. cm_val = confusion_matrix(true_labels, pred_labels)
  37. a_val = cm_val[0,0]
  38. b_val = cm_val[0,1]
  39. c_val = cm_val[1,0]
  40. d_val = cm_val[1,1]
  41. # 计算各种性能指标
  42. acc_val = (a_val+d_val)/(a_val+b_val+c_val+d_val) # 准确率
  43. error_rate_val = 1 - acc_val # 错误率
  44. sen_val = d_val/(d_val+c_val) # 灵敏度
  45. sep_val = a_val/(a_val+b_val) # 特异度
  46. precision_val = d_val/(b_val+d_val) # 精确度
  47. F1_val = (2*precision_val*sen_val)/(precision_val+sen_val) # F1值
  48. MCC_val = (d_val*a_val-b_val*c_val) / (np.sqrt((d_val+b_val)*(d_val+c_val)*(a_val+b_val)*(a_val+c_val))) # 马修斯相关系数
  49. # 打印出性能指标
  50. print("验证集的灵敏度为:", sen_val,
  51. "验证集的特异度为:", sep_val,
  52. "验证集的准确率为:", acc_val,
  53. "验证集的错误率为:", error_rate_val,
  54. "验证集的精确度为:", precision_val,
  55. "验证集的F1为:", F1_val,
  56. "验证集的MCC为:", MCC_val)
  57. # 绘制混淆矩阵
  58. plot_cm(true_labels, pred_labels)
  59. # 获取预测和真实标签
  60. train_true_labels, train_pred_labels = evaluate_model(model, dataloaders['train'], device)
  61. # 计算混淆矩阵
  62. cm_train = confusion_matrix(train_true_labels, train_pred_labels)
  63. a_train = cm_train[0,0]
  64. b_train = cm_train[0,1]
  65. c_train = cm_train[1,0]
  66. d_train = cm_train[1,1]
  67. acc_train = (a_train+d_train)/(a_train+b_train+c_train+d_train)
  68. error_rate_train = 1 - acc_train
  69. sen_train = d_train/(d_train+c_train)
  70. sep_train = a_train/(a_train+b_train)
  71. precision_train = d_train/(b_train+d_train)
  72. F1_train = (2*precision_train*sen_train)/(precision_train+sen_train)
  73. MCC_train = (d_train*a_train-b_train*c_train) / (math.sqrt((d_train+b_train)*(d_train+c_train)*(a_train+b_train)*(a_train+c_train)))
  74. print("训练集的灵敏度为:",sen_train,
  75. "训练集的特异度为:",sep_train,
  76. "训练集的准确率为:",acc_train,
  77. "训练集的错误率为:",error_rate_train,
  78. "训练集的精确度为:",precision_train,
  79. "训练集的F1为:",F1_train,
  80. "训练集的MCC为:",MCC_train)
  81. # 绘制混淆矩阵
  82. plot_cm(train_true_labels, train_pred_labels)

效果不错:

 (g)AUC曲线绘制

  1. from sklearn import metrics
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from matplotlib.pyplot import imshow
  5. from sklearn.metrics import classification_report, confusion_matrix
  6. import seaborn as sns
  7. import pandas as pd
  8. import math
  9. def plot_roc(name, labels, predictions, **kwargs):
  10. fp, tp, _ = metrics.roc_curve(labels, predictions)
  11. plt.plot(fp, tp, label=name, linewidth=2, **kwargs)
  12. plt.plot([0, 1], [0, 1], color='orange', linestyle='--')
  13. plt.xlabel('False positives rate')
  14. plt.ylabel('True positives rate')
  15. ax = plt.gca()
  16. ax.set_aspect('equal')
  17. # 确保模型处于评估模式
  18. model.eval()
  19. train_ds = dataloaders['train']
  20. val_ds = dataloaders['val']
  21. val_pre_auc = []
  22. val_label_auc = []
  23. for images, labels in val_ds:
  24. for image, label in zip(images, labels):
  25. img_array = image.unsqueeze(0).to(device) # 在第0维增加一个维度并将图像转移到适当的设备上
  26. prediction_auc = model(img_array) # 使用模型进行预测
  27. val_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1])
  28. val_label_auc.append(label.item()) # 使用Tensor.item()获取Tensor的值
  29. auc_score_val = metrics.roc_auc_score(val_label_auc, val_pre_auc)
  30. train_pre_auc = []
  31. train_label_auc = []
  32. for images, labels in train_ds:
  33. for image, label in zip(images, labels):
  34. img_array_train = image.unsqueeze(0).to(device)
  35. prediction_auc = model(img_array_train)
  36. train_pre_auc.append(prediction_auc.detach().cpu().numpy()[:,1]) # 输出概率而不是标签!
  37. train_label_auc.append(label.item())
  38. auc_score_train = metrics.roc_auc_score(train_label_auc, train_pre_auc)
  39. plot_roc('validation AUC: {0:.4f}'.format(auc_score_val), val_label_auc , val_pre_auc , color="red", linestyle='--')
  40. plot_roc('training AUC: {0:.4f}'.format(auc_score_train), train_label_auc, train_pre_auc, color="blue", linestyle='--')
  41. plt.legend(loc='lower right')
  42. #plt.savefig("roc.pdf", dpi=300,format="pdf")
  43. print("训练集的AUC值为:",auc_score_train, "验证集的AUC值为:",auc_score_val)

ROC曲线如下:

 这个ROC曲线也是不错的!全部大于95%!

三、写在最后

截至目前,图像分类领域基本就是CNN、Transformer和MLP三足鼎立了。孰优孰劣,还不好说,中庸之道那就是各有千秋。他们之间的两两组合或者一起融合的话,效果又会如何?

四、数据

链接:https://pan.baidu.com/s/15vSVhz1rQBtqNkNp2GQyVw?pwd=x3jf

提取码:x3jf

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/248451
推荐阅读
相关标签
  

闽ICP备14008679号