赞
踩
import torch
from torch import nn
from net import MyAlexNet
import numpy as np
from torch.optim import lr_scheduler
import os
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
2.定义train和test数据集,把这两个数据集的路径写好;记得反斜杠,路径都是绝对路径,要根据自己的路径进行配置,不难吧
ROOT_TRAIN = r'C:/Users/Desktop/AlexNet/data/train'
ROOT_TEST = r'C:/Users/Desktop/AlexNet/data/val'
3.定义将数据进行归一化,其实就是将图片的像素点都固定在[0,1]之间,作用是这种归一化操作可以加快模型的训练速度,这个方法为啥能加快模型的训练速度,网上有公式可以看看,该操作不使用也可以
# 将图像的像素值归一化到【-1, 1】之间
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
4.定义对数据进行数据的操作,翻转,选装等都可以不适用,但是 transforms.ToTensor()是必不可以少的,必须有
训练集有RandomVerticalFlip数据增强
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
normalize])
测试集没有RandomVerticalFlip数据增强
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
normalize])
5.使用ImageFold传入train和test数据集
train_dataset = ImageFolder(ROOT_TRAIN, transform=train_transform)
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)
6.使用DataLoader方法将ImageFold传入的数据放在GPU上
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
确定使用GPU还是CPU跑程序
device = 'cuda' if torch.cuda.is_available() else 'cpu'
7.将模型放在GPU设备上跑,如果你没有GPU就是放在CPU设备上跑,其中本文的模型名字是:“MyAlexNet”,如果你的不是请改成你自己的,比如:ResNet-50等
model = MyAlexNet().to(device)
8.定义损失函数,定义优化器,定义学习率,第二行代码是将模型参数等放入优化器
# 定义一个损失函数
loss_fn = nn.CrossEntropyLoss()
# 定义一个优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# 学习率每隔10轮变为原来的0.5
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
9.定义训练函数,该方法是比较重要的,将数据、模型结构、损失函数、优化器等传入模型。具体的代码如下:
定义训练函数
# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
loss, current, n = 0.0, 0.0, 0
for batch, (x, y) in enumerate(dataloader):
image, y = x.to(device), y.to(device)
output = model(image)
cur_loss = loss_fn(output, y)
_, pred = torch.max(output, axis=1)
cur_acc = torch.sum(y==pred) / output.shape[0]
# 反向传播
optimizer.zero_grad()
cur_loss.backward()
optimizer.step()
loss += cur_loss.item()
current += cur_acc.item()
n = n+1
train_loss = loss / n
train_acc = current / n
print('train_loss' + str(train_loss))
print('train_acc' + str(train_acc))
return train_loss, train_acc
定义测试函数
# 定义一个验证函数
def val(dataloader, model, loss_fn):
# 将模型转化为验证模型
model.eval()
loss, current, n = 0.0, 0.0, 0
with torch.no_grad():
for batch, (x, y) in enumerate(dataloader):
image, y = x.to(device), y.to(device)
output = model(image)
cur_loss = loss_fn(output, y)
_, pred = torch.max(output, axis=1)
cur_acc = torch.sum(y == pred) / output.shape[0]
loss += cur_loss.item()
current += cur_acc.item()
n = n + 1
val_loss = loss / n
val_acc = current / n
print('val_loss' + str(val_loss))
print('val_acc' + str(val_acc))
return val_loss, val_acc
10,到了这一步基本上完成了分类模型的搭建,最后是定义画图函数
# 定义画图函数
def matplot_loss(train_loss, val_loss):
plt.plot(train_loss, label='train_loss')
plt.plot(val_loss, label='val_loss')
plt.legend(loc='best')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.title("训练集和验证集loss值对比图")
plt.show()
def matplot_acc(train_acc, val_acc):
plt.plot(train_acc, label='train_acc')
plt.plot(val_acc, label='val_acc')
plt.legend(loc='best')
plt.ylabel('acc')
plt.xlabel('epoch')
plt.title("训练集和验证集acc值对比图")
plt.show()
11.开始训练模型
# 开始训练
loss_train = []
acc_train = []
loss_val = []
acc_val = []
epoch = 20
min_acc = 0
for t in range(epoch):
lr_scheduler.step()
print(f"epoch{t+1}\n-----------")
train_loss, train_acc = train(train_dataloader, model, loss_fn, optimizer)
val_loss, val_acc = val(val_dataloader, model, loss_fn)
loss_train.append(train_loss)
acc_train.append(train_acc)
loss_val.append(val_loss)
acc_val.append(val_acc)
# 保存最好的模型权重
if val_acc >min_acc:
folder = 'save_model'
if not os.path.exists(folder):
os.mkdir('save_model')
min_acc = val_acc
print(f"save best model, 第{t+1}轮")
torch.save(model.state_dict(), 'save_model/best_model.pth')
# 保存最后一轮的权重文件
if t == epoch-1:
torch.save(model.state_dict(), 'save_model/last_model.pth')
matplot_loss(loss_train, loss_val)
matplot_acc(acc_train, acc_val)
print('Done!')
注释:你除了需要自己写网络模型,将调用的模型改成自己的模型,其他的都可以复制粘贴进行,将上面所有的代码放到一个文件夹里面就行,然后修改调用自己的模型。
再次感谢这个博主做的视频讲座:https://www.bilibili.com/video/BV18L4y167jr?p=1
可以学习他的方法来看我这个博客,效率会好很多
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。