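This example uses a subset of the Plant Seedlings dataset, which has 12 classes, to demonstrate image classification with a PyTorch implementation of the ViT (Vision Transformer) model.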
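In this article you will learn:
1. How to build a ViT model.
2. How to generate the dataset.
3. How to use Cutout data augmentation.
4. How to use Mixup data augmentation.
5. How to implement training and validation.
6. How to adjust the learning rate with cosine annealing.
7. Two ways to write the prediction code.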
The code in this article is deliberately plain and simple, so it should be easy to follow.
VIT_demo
├─models
│ └─vision_transformer.py
├─data
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
├─mean_std.py
├─makedata.py
├─train.py
├─test1.py
└─test.py
mean_std.py: computes the mean and std values.
makedata.py: generates the dataset.
To help the model converge faster, we first compute the per-channel mean and std of the images. Create mean_std.py and add the following code:
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms


def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    # Accumulate the per-image mean and std of each RGB channel
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    # Average over the number of images
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':
    # data1 holds the raw, not yet split dataset (one sub-folder per class)
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))
Output:
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
Write this result down; it is needed later!
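The script above averages per-image means and per-image standard deviations. Because the images differ in size, and because an average of standard deviations is not the standard deviation of all pixels, the result is an approximation (usually good enough for normalization). For comparison, here is a minimal NumPy sketch of the pooled per-pixel statistics. It assumes each image is already a float array of shape (3, H, W) in [0, 1], as ToTensor() would produce; `pooled_mean_std` is a hypothetical helper, not part of the original project.

```python
import numpy as np


def pooled_mean_std(images):
    """Per-channel mean/std over all pixels of all images.

    images: iterable of float arrays shaped (3, H, W); H and W may differ
    between images. Only running sums are kept, so the whole dataset
    never has to fit in memory.
    """
    total = np.zeros(3)
    total_sq = np.zeros(3)
    count = 0
    for img in images:
        flat = img.reshape(3, -1)            # (3, H*W)
        total += flat.sum(axis=1)
        total_sq += (flat ** 2).sum(axis=1)
        count += flat.shape[1]
    mean = total / count
    # Population std: sqrt(E[x^2] - E[x]^2), clipped to avoid tiny negative rounding errors
    std = np.sqrt(np.maximum(total_sq / count - mean ** 2, 0.0))
    return mean.tolist(), std.tolist()
```

Feeding it the same tensors (converted with `.numpy()`) lets you check how far the per-image average drifts from the pooled values on your data.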
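The image classification dataset we prepared is organized like this, with one folder per class (the scripts in this article read this raw dataset from a folder named data1):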
data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet
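By default, PyTorch and Keras load data in the ImageNet folder layout, with separate train and val directories that each contain one sub-folder per class: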
├─data
│ ├─val
│ │ ├─Black-grass
│ │ ├─Charlock
│ │ ├─Cleavers
│ │ ├─Common Chickweed
│ │ ├─Common wheat
│ │ ├─Fat Hen
│ │ ├─Loose Silky-bent
│ │ ├─Maize
│ │ ├─Scentless Mayweed
│ │ ├─Shepherds Purse
│ │ ├─Small-flowered Cranesbill
│ │ └─Sugar beet
│ └─train
│   ├─Black-grass
│   ├─Charlock
│   ├─Cleavers
│   ├─Common Chickweed
│   ├─Common wheat
│   ├─Fat Hen
│   ├─Loose Silky-bent
│   ├─Maize
│   ├─Scentless Mayweed
│   ├─Shepherds Purse
│   ├─Small-flowered Cranesbill
│   └─Sugar beet
Add a conversion script, makedata.py, with the following code:
import glob
import os
import shutil

from sklearn.model_selection import train_test_split

# Collect every image of the raw dataset: data1/<class>/<image>.png
image_list = glob.glob('data1/*/*.png')
print(image_list)

# Delete the output folder if it already exists, then recreate it
file_dir = 'data'
if os.path.exists(file_dir):
    print('true')
    shutil.rmtree(file_dir)
os.makedirs(file_dir)

# Random 70% / 30% split into train and val
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)

for file in trainval_files:
    # Normalize Windows separators, then take <class> and <file name> from the path
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(train_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(val_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)
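`train_test_split` here is called without `stratify=`, so the 70/30 ratio holds for the dataset as a whole but only approximately for each of the 12 classes. Passing the class labels as `stratify=` to `train_test_split` is one fix; another is to split each class folder separately. Below is a stdlib-only sketch of the second option, assuming the same `data1/<class>/*.png` layout; `stratified_split` is a hypothetical helper, not the author's script.

```python
import random
import shutil
from pathlib import Path


def stratified_split(src_root, dst_root, val_ratio=0.3, seed=42, pattern='*.png'):
    """Copy src_root/<class>/<file> to dst_root/{train,val}/<class>/<file>.

    Each class is shuffled and split on its own, so every class keeps
    (up to rounding) the same train/val ratio.
    """
    rng = random.Random(seed)
    src_root, dst_root = Path(src_root), Path(dst_root)
    for class_dir in sorted(p for p in src_root.iterdir() if p.is_dir()):
        files = sorted(class_dir.glob(pattern))
        rng.shuffle(files)
        n_val = round(len(files) * val_ratio)
        splits = {'val': files[:n_val], 'train': files[n_val:]}
        for split, split_files in splits.items():
            out_dir = dst_root / split / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            for f in split_files:
                shutil.copy(f, out_dir / f.name)
```

Usage: `stratified_split('data1', 'data')` produces the same data/train and data/val layout that ImageFolder expects. Unlike makedata.py, this sketch does not delete an existing output folder first.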
To improve accuracy, I added two augmentation methods to the code: Cutout and Mixup. Both come from torchtoolbox, which you can install with:
pip install torchtoolbox
Cutout is applied as part of the transforms:
import torchvision.transforms as transforms
from torchtoolbox.transform import Cutout

# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout()
])
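Cutout masks out a random rectangular patch of the image so the model cannot rely on any single region. The NumPy sketch below shows the idea with one fixed-size square filled with 0 on an H×W×C array. It is a simplified illustration, not torchtoolbox's implementation, whose `Cutout` has its own parameters for the patch size, shape, fill value and probability (see its documentation); `cutout` here is a hypothetical helper.

```python
import numpy as np


def cutout(img, size, rng=None):
    """Return a copy of img (H, W, C) with one size x size square set to 0.

    The square's centre is sampled anywhere in the image, so the patch can
    extend past the borders; it is clipped there.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = img.shape[:2]
    cy, cx = rng.integers(h), rng.integers(w)   # random centre
    y0, x0 = cy - size // 2, cx - size // 2
    y1, x1 = y0 + size, x0 + size
    out = img.copy()                            # leave the input untouched
    out[max(y0, 0):min(y1, h), max(x0, 0):min(x1, w)] = 0
    return out
```

Because the centre may land near an edge, the masked area varies between 1 pixel and `size * size` pixels, which adds extra variety to the augmentation.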
Mixup is applied inside the train function. It needs this import: from torchtoolbox.tools import mixup_data, mixup_criterion
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
    # Mix the batch: labels_a and labels_b are the two label sets, lam the mixing weight
    data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
    optimizer.zero_grad()
    output = model(data)
    # Loss weighted by lam on labels_a and (1 - lam) on labels_b
    loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
    loss.backward()
    optimizer.step()
    print_loss = loss.item()
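For reference, standard Mixup draws λ ~ Beta(α, α), blends every sample with a randomly permuted partner from the same batch, and weights the loss on the two label sets by λ and 1 − λ. The NumPy sketch below follows that textbook formulation; it is not torchtoolbox's source, and `mixup_batch`, `cross_entropy` and `mixup_loss` are hypothetical names.

```python
import numpy as np


def mixup_batch(x, y, alpha, rng):
    """Return lam * x + (1 - lam) * x[perm], both label sets, and lam."""
    lam = rng.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = rng.permutation(len(x))
    mixed = lam * x + (1 - lam) * x[perm]
    return mixed, y, y[perm], lam


def cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax(logits)."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def mixup_loss(logits, labels_a, labels_b, lam):
    """lam-weighted sum of the losses against both label sets."""
    return lam * cross_entropy(logits, labels_a) + (1 - lam) * cross_entropy(logits, labels_b)
```

With the article's `alpha = 0.2`, Beta(0.2, 0.2) puts most of its mass near 0 and 1, so most batches stay close to one of the original images and only some are strongly blended.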
Next comes train.py. First, import the libraries it uses:
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from models.vision_transformer import deit_tiny_patch16_224
from torchtoolbox.tools import mixup_data, mixup_criterion
from torchtoolbox.transform import Cutout
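Set the learning rate, batch size, number of epochs and other parameters, and check whether a GPU is available; if not, fall back to the CPU. A GPU is strongly recommended, as training on the CPU is very slow.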
# Global parameters
modellr = 1e-4
BATCH_SIZE = 16
EPOCHS = 300
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
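Data preprocessing is simple: Cutout is added, and the images are resized and normalized. Pass the mean and std computed above to transforms.Normalize. The test transforms leave out Cutout, because augmentation is only applied during training.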
# Data preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
Cutout(),
transforms.ToTensor(),
transforms.Normalize([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
])
transform_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
])
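`transforms.Normalize(mean, std)` computes `(x - mean[c]) / std[c]` for each channel c of the tensor produced by `ToTensor()` (values in [0, 1]). The NumPy sketch below reproduces that arithmetic with the mean and std computed earlier, plus the inverse, which is handy for checking values by hand or for displaying a normalized input. The helper names are hypothetical.

```python
import numpy as np

MEAN = np.array([0.3281186, 0.28937867, 0.20702125])
STD = np.array([0.09407319, 0.09732835, 0.106712654])


def normalize(chw):
    """Same arithmetic as transforms.Normalize on a (C, H, W) array in [0, 1]."""
    return (chw - MEAN[:, None, None]) / STD[:, None, None]


def denormalize(chw):
    """Inverse of normalize(), e.g. to visualize a normalized input image."""
    return chw * STD[:, None, None] + MEAN[:, None, None]
```

A mid-gray pixel of 0.5 in the red channel, for example, becomes (0.5 − 0.3281) / 0.0941 ≈ 1.83 after normalization.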
Load the data the default PyTorch way, with ImageFolder, and print dataset_train.class_to_idx; this mapping is needed at prediction time.
# Read the data
dataset_train = datasets.ImageFolder('data/train', transform=transform)
dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
print(dataset_train.class_to_idx)
# Wrap the datasets in data loaders
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
Output of class_to_idx:
{'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}
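ImageFolder builds `class_to_idx` by sorting the names of the class sub-folders and numbering them from 0, which is why the indices above follow alphabetical (code-point) order; the sorted names are also available as `dataset_train.classes`. A stdlib sketch that mimics this ordering, useful for checking a folder before training; `class_mapping` is a hypothetical helper.

```python
import os


def class_mapping(root):
    """Mimic ImageFolder's class ordering: sorted sub-directory names -> 0..N-1."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx
```

Usage: `classes, class_to_idx = class_mapping('data/train')`. Files placed directly in the root folder are ignored, just as only sub-folders count as classes in ImageFolder.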
The model file comes from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
I modified this script so that it can load pretrained weights in .pth format; it cannot load pretrained weights in .npz format.
# Instantiate the model and move it to the GPU
criterion = nn.CrossEntropyLoss()
model_ft = deit_tiny_patch16_224(pretrained=True)
print(model_ft)
num_ftrs = model_ft.head.in_features
model_ft.head = nn.Linear(num_ftrs, 12,bias=True)
nn.init.xavier_uniform_(model_ft.head.weight)
model_ft.to(DEVICE)
print(model_ft)
# Use the simple, no-frills Adam optimizer with a low learning rate
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)
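`CosineAnnealingLR` follows the closed form η_t = η_min + (η_max − η_min)(1 + cos(π·t / T_max)) / 2, where η_max is the optimizer's initial `lr` and t counts scheduler steps. The scheduler does not stop at `T_max`: the cosine keeps going, so with `T_max=20` and 300 epochs the learning rate falls from 1e-4 to 1e-9 and climbs back every 40 epochs. A stdlib sketch of the closed form for planning a schedule; `cosine_lr` is a hypothetical helper, not a PyTorch API.

```python
import math


def cosine_lr(step, base_lr=1e-4, eta_min=1e-9, t_max=20):
    """Closed-form cosine-annealing learning rate after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


if __name__ == '__main__':
    for step in (0, 10, 20, 30, 40):
        print(step, f'{cosine_lr(step):.3e}')
```

Because `cosine_schedule.step()` is called after each epoch's training, epoch `e` of the training loop trains with `cosine_lr(e - 1)`. If a single decay over the whole run is what you want, set `T_max` to the number of epochs instead.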
# Define the training process
alpha = 0.2


def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        data, labels_a, labels_b, lam = mixup_data(data, target, alpha)
        optimizer.zero_grad()
        output = model(data)
        loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)
        loss.backward()
        optimizer.step()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


ACC = 0


# Validation process
def val(model, device, test_loader):
    global ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output, 1)
            correct += torch.sum(pred == target)
            test_loss += loss.item()
        correct = correct.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        # Save the whole model whenever validation accuracy improves
        # (epoch is the global loop variable of the training loop below)
        if acc > ACC:
            torch.save(model, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
            ACC = acc


# Training
for epoch in range(1, EPOCHS + 1):
    train(model_ft, DEVICE, train_loader, optimizer, epoch)
    cosine_schedule.step()
    val(model_ft, DEVICE, test_loader)
Below is a general-purpose way to predict: load the test images yourself and classify them one by one. The steps are as follows:
The test images are stored directly in a test/ directory, which is the path the prediction script reads from.
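Step 1: define the classes. Their order must match the class order used during training; do not change it!
Step 2: define the transforms. Use the same transforms as the validation set, without any data augmentation.
Step 3: load the model and move it to DEVICE.
Step 4: read each image and predict its class. Read the images with PIL's Image, not cv2: the transforms (e.g. Resize) expect PIL images, whereas cv2 returns BGR NumPy arrays.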
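Step 1 is where a silent mistake is easiest: if the class tuple is out of order, every prediction gets the wrong name. Instead of typing it by hand, you can invert the `class_to_idx` dictionary printed during training (in train.py, `dataset_train.classes` already holds the same ordered list). A small stdlib sketch; `idx_to_classes` is a hypothetical helper.

```python
def idx_to_classes(class_to_idx):
    """Turn {'class name': index} into a tuple whose i-th entry is the name of class i."""
    names = [None] * len(class_to_idx)
    for name, idx in class_to_idx.items():
        # Every index must be in 0..N-1 and appear exactly once
        if not 0 <= idx < len(names) or names[idx] is not None:
            raise ValueError('class indices must be exactly 0..N-1 with no repeats')
        names[idx] = name
    return tuple(names)
```

Usage: paste the dictionary printed by train.py into `idx_to_classes(...)` and use the result as `classes`; an inconsistent mapping raises an error instead of mislabeling predictions.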
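import os

import torch
import torchvision.transforms as transforms
from PIL import Image

# Must match the class order printed as class_to_idx during training
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent', 'Maize',
           'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill',
           'Sugar beet')

# Same preprocessing as validation: no augmentation
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.pth is a whole model saved by train.py (e.g. a renamed model_<epoch>_<acc>.pth),
# so the models package must be importable. map_location lets a GPU-trained model load
# on a CPU-only machine; on PyTorch >= 2.6 a whole pickled model also needs weights_only=False.
model = torch.load("model.pth", map_location=DEVICE)
model.eval()
model.to(DEVICE)

path = 'test/'
testList = os.listdir(path)
with torch.no_grad():
    for file in testList:
        # convert('RGB') drops any alpha channel, matching how ImageFolder loaded the training images
        img = Image.open(path + file).convert('RGB')
        img = transform_test(img)
        img.unsqueeze_(0)  # add a batch dimension
        img = img.to(DEVICE)
        out = model(img)
        # Predict
        _, pred = torch.max(out, 1)
        print('Image Name:{},predict:{}'.format(file, classes[pred.item()]))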
Full code:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/81737304