赞
踩
附上代码:
- import os
- import sys
- import json
-
- import torch
- import torch.nn as nn
- from torchvision import transforms, datasets, utils
- import matplotlib.pyplot as plt
- import numpy as np
- import torch.optim as optim
- from tqdm import tqdm
-
- from model import AlexNet
-
-
- def main():
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#表示如果当前有可使用的GPU设备,就使用设备上的GPU设备,没有就是用CPU设备
- print("using {} device.".format(device))
-
- data_transform = {
- "train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪,将图像裁剪至224X224的大小
- transforms.RandomHorizontalFlip(),#在水平放心随机翻转
- transforms.ToTensor(),#转换成tensor
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
- "val": transforms.Compose([transforms.Resize((224, 224)), # cannot 224, must (224, 224)
- transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
-
- data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path #返回绝对路径
- image_path = os.path.join(data_root, "data_set", "flower_data") # flower data set path
- assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
- train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
- transform=data_transform["train"])
- train_num = len(train_dataset)#打印训练集有多少张图片
-
- # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
- flower_list = train_dataset.class_to_idx#利用class_to_idx去获取分类名称所对应的索引
- cla_dict = dict((val, key) for key, val in flower_list.items())
- # write dict into json file
- json_str = json.dumps(cla_dict, indent=4)
- with open('class_indices.json', 'w') as json_file:
- json_file.write(json_str)
-
- batch_size = 32
- nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
- print('Using {} dataloader workers every process'.format(nw))
-
- train_loader = torch.utils.data.DataLoader(train_dataset,
- batch_size=batch_size, shuffle=True,
- num_workers=nw)
-
- validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
- transform=data_transform["val"])
- val_num = len(validate_dataset)
- validate_loader = torch.utils.data.DataLoader(validate_dataset,
- batch_size=4, shuffle=True,
- num_workers=nw)
-
- print("using {} images for training, {} images for validation.".format(train_num,
- val_num))
- # test_data_iter = iter(validate_loader)
- # test_image, test_label = test_data_iter.__next__()
- #
- # def imshow(img):
- # img = img / 2 + 0.5 # unnormalize
- # npimg = img.numpy()
- # plt.imshow(np.transpose(npimg, (1, 2, 0)))
- # plt.show()
- #
- # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
- # imshow(utils.make_grid(test_image))
-
- net = AlexNet(num_classes=5, init_weights=True)
-
- net.to(device)
- loss_function = nn.CrossEntropyLoss()
- # pata = list(net.parameters())
- optimizer = optim.Adam(net.parameters(), lr=0.0002)
-
- epochs = 10
- save_path = './AlexNet.pth'
- best_acc = 0.0
- train_steps = len(train_loader)
- for epoch in range(epochs):
- # train
- net.train()
- running_loss = 0.0
- train_bar = tqdm(train_loader, file=sys.stdout)
- for step, data in enumerate(train_bar):
- images, labels = data
- optimizer.zero_grad()
- outputs = net(images.to(device))
- loss = loss_function(outputs, labels.to(device))
- loss.backward()
- optimizer.step()
-
- # print statistics
- running_loss += loss.item()
-
- train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
- epochs,
- loss)
-
- # validate
- net.eval()
- acc = 0.0 # accumulate accurate number / epoch
- with torch.no_grad():
- val_bar = tqdm(validate_loader, file=sys.stdout)
- for val_data in val_bar:
- val_images, val_labels = val_data
- outputs = net(val_images.to(device))
- predict_y = torch.max(outputs, dim=1)[1]
- acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
-
- val_accurate = acc / val_num
- print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
- (epoch + 1, running_loss / train_steps, val_accurate))
-
- if val_accurate > best_acc:
- best_acc = val_accurate
- torch.save(net.state_dict(), save_path)
-
- print('Finished Training')
-
-
- if __name__ == '__main__':
- main()
解释顺序就是代码阅读顺序
训练数据集处理:
1.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
同
- if torch.cuda.is_available():
- device = torch.device("cuda")
- else:
- device = torch.device("cpu")
表示如果有GPU使用GPU进行计算训练,否则使用CPU
2.
print("using {} device.".format(device))
一种格式化字符串的函数str.format()Python format 格式化函数 | 菜鸟教程 (runoob.com)
3.transforms.Compose():预处理函数
4.transforms.RandomResizedCrop(224):随机裁剪,将图像裁剪至224X224的大小
5.transforms.RandomHorizontalFlip():随机翻转,数据增强一种方法,这里是水平翻转。
6.transforms.ToTensor():转换成张量
7.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):标准化处理,第一个(0.5, 0.5, 0.5)为均值,第二个(0.5, 0.5, 0.5)为方差
测试数据集处理:
1.transforms.Resize((224, 224):图像大小改为224
2.transforms.ToTensor():转换成张量
3.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):标准化处理
获取数据集:
1.
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
os.path.abspath():获取数据集所在根目录,即返回绝对路径
os.path.join():将传入两个路径连接在一起
os.getcwd():获取当前所在文件的目录
"../..":返回上上级目录
2.
image_path = os.path.join(data_root, "data_set", "flower_data")
从根目录开始向下进行完整目录的拼接
3.
assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
os.path.exists():函数的功能是查看给定的文件/目录是否存在,存在返回True,不存在返回False。
assert:Python3 assert(断言) | 菜鸟教程 (runoob.com)
为assert断言语句添加异常参数:assert的异常参数,其实就是在断言表达式后添加字符串信息,用来解释断言并更好的知道是哪里出了问题。格式如下:
assert expression [, arguments]
assert 表达式 [, 参数]
4.
- train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
- transform=data_transform["train"])
datasets.ImageFolder():加载数据集
root=image_path + "/train":传入训练集数据路径
transform=data_transform["train"]:调用训练数据集预处理模块 即:
"train": transforms.Compose([transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
5. train_num = len(train_dataset):打印训练集有多少张图片
分类处理:
1.flower_list = train_dataset.class_to_idx:利用class_to_idx去获取分类名称所对应的索引
2.cla_dict = dict((val, key) for key, val in flower_list.items()) :循环遍历数组索引该值并交换重新赋值给数组,这样模型预测出来的直接就是value类别值。Python 字典(Dictionary) items()方法 | 菜鸟教程 (runoob.com)
3.json_str = json.dumps(cla_dict, indent=4):把字典编码成json格式,indent参数决定添加几个空格
4. with open('class_indices.json', 'w') as json_file: json_file.write(json_str):把字典类别索引写入json文件
5.batch_size = 32:一次性载入32张图片
6.torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=0):加载数据集和其他参数
7.datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"]):加载测试集路径和测试集预处理模块。
查看数据集代码:
- test_data_iter = iter(validate_loader)
- test_image, test_label = test_data_iter.__next__()
-
- def imshow(img):
- img = img / 2 + 0.5 # unnormalize
- npimg = img.numpy()
- plt.imshow(np.transpose(npimg, (1, 2, 0)))
- plt.show()
-
- print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
- imshow(utils.make_grid(test_image))
使用设计的模型:
1.AlexNet(num_classes=5, init_weights=True):传入参数调用类
2.num_classes=5:类别数为5,最终全连接生成5个值的向量。
3.init_weights=True:初始化模型训练权重参数,从头开始训练
4.net.to(device):设备(GPU或CPU)加载网络。
5.loss_function = nn.CrossEntropyLoss():设置损失函数。CrossEntropyLoss为交叉熵损失函数。
6.optimizer = optim.Adam(net.parameters(), lr=0.0002):设置Adam优化器。根据模型当前参数决定优化即调整参数的增减和幅度。Ir为学习率。过大过小都会影响准确率。
7.save_path = './AlexNet.pth':设置保存权重的路径
8.best_acc = 0.0:设置准确率变量
开始训练:
1.for epoch in range(10):遍历迭代10次
2.net.train():调用Dropout方法
3.running_loss = 0.0:设置训练损失值
4.t1 = time.perf_counter():设置记录训练开始时间以计算一个epoch所花费时间
5.for step, data in enumerate(train_loader, start=0):遍历数据集,返回数据data和步长step
6.images, labels = data:把data数组中的图像和标签分别赋值给变量images和label。
7.optimizer.zero_grad():清空之前的梯度信息。作用是将历史损失梯度进行清零,一般batch_size这个数值设置的越大,训练效果越好,但由于硬件设备受限,不可能用一个很大的batch_size进行训练,而通过Optimizer.zero_grad()可以变相实现一个很大batch数目的训练,即一次性计算多个小的batch。
8.outputs = net(images.to(device)):开始进行正向传播,并把图像计算与设备进行绑定。
9.loss = loss_function(outputs, labels.to(device)):训练得到预测输出之后与真实标签进行计算损失值。
10.loss.backward():将loss反向传播到各个节点。
11.optimizer.step():更新每个节点参数
12.running_loss += loss.item():进行一个loss的累加
13.rate = (step + 1) / len(train_loader):当前训练步数,比如72/900、73/900。 len(train_loader)表示训练一轮需要的步数
14.a = "*" * int(rate * 50),b = "." * int((1 - rate) * 50):使用*和.,打印进度百分比
15.print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end=""):打印训练进度信息
验证过程:
1.net.eval():控制不使用Dropout。
2. with torch.no_grad():该函数在接下来计算过程中不要去计算每个节点的误差损失梯度,如果没有使用它,在运行时就会消耗更多的算力,占用更多的内存资源,甚至会内存崩掉。在这个函数的所有范围内的计算都不会去计算它的误差梯度
3.for val_data in validate_loader:遍历验证数据集
4.val_images, val_labels = val_data:把data数组中的图像和标签分别赋值给变量images和label
5.outputs = net(val_images.to(device)):开始进行正向传播,并把图像计算与设备进行绑定
6.predict_y = torch.max(outputs, dim=1)[1]:这代码的意思是获得这个batch中网络的预测标签,torch.max(outputs, dim=1)返回两个值,分别是最大值和其对应的索引,dim=1时按行返回最大索引,dim=0时按列返回最大索引。
7.acc += (predict_y == val_labels.to(device)).sum().item():将预测值与真实值进行比较,相等为1,不等为0,并求和,并通过item()获得数值。
8.val_accurate = acc / val_num:预测正确值之和除以总和计算准确率
9. if val_accurate > best_acc:
best_acc = val_accurate
torch.save(net.state_dict(), save_path):判断当前准确率是否大于历史准确率,如果是保存当前模型权重
10.print('[epoch %d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))
print('Finished Training'):打印训练轮数、损失值、步长、准确率,结束训练。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。