The code is based on the flower-classification example.
1. tqdm
tqdm simply wraps an iterable and displays a progress bar while you loop over it.
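A minimal sketch of how it is used (the range and description text below are made up purely for demonstration; the training script uses it the same way around train_loader):

```python
import sys
from tqdm import tqdm

# wrap any iterable; tqdm prints and updates a progress bar as the loop runs
bar = tqdm(range(100), file=sys.stdout)
for i in bar:
    bar.desc = "step {}".format(i)  # text shown in front of the bar
```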
2. net.train() and net.eval()
Their purpose is to switch the network between training and evaluation behavior, so that Dropout is not applied during validation.
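A small sketch of what the switch does (the tiny model below is hypothetical, only to make the effect visible): in train mode Dropout randomly zeroes activations, so repeated forward passes differ; in eval mode Dropout is disabled and the output is deterministic.

```python
import torch
import torch.nn as nn

# hypothetical tiny model with a Dropout layer, just to demonstrate the switch
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 5))
x = torch.randn(4, 8)

net.train()                  # Dropout active: two forward passes usually differ
y1, y2 = net(x), net(x)
print(torch.equal(y1, y2))   # usually False

net.eval()                   # Dropout disabled: forward passes are deterministic
with torch.no_grad():
    y3, y4 = net(x), net(x)
print(torch.equal(y3, y4))   # True
```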
```python
net = AlexNet(num_classes=5, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
# pata = list(net.parameters())
optimizer = optim.Adam(net.parameters(), lr=0.0002)

epochs = 10
save_path = './AlexNet.pth'
best_acc = 0.0
train_steps = len(train_loader)
for epoch in range(epochs):
    # train
    net.train()
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
```
1. First, let's look at the torch.max(input, dim) function
```python
output = torch.max(input, dim)
```

Input
- input is a tensor, e.g. the class scores produced by the network (or by softmax).
- dim is the dimension along which the maximum is taken: 0 gives the maximum of each column, 1 gives the maximum of each row.

Output
- The function returns two tensors: the first contains the maximum value of each row (for softmax outputs, these are the largest probabilities); the second contains the index of each row's maximum.
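A quick, self-contained check of this behavior (the score values below are chosen arbitrarily):

```python
import torch

# 2 "images" x 3 "classes"
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])

values, indices = torch.max(scores, dim=1)  # maximum along each row
print(values)   # tensor([0.7000, 0.5000])  -> the per-row maxima
print(indices)  # tensor([1, 0])            -> column index of each maximum

# torch.max(scores, dim=1)[1] is exactly the indices tensor
print(torch.max(scores, dim=1)[1])  # tensor([1, 0])
```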
2. Computing the accuracy
The code is for 5-class flower recognition {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflower', 4: 'tulips'}, with batch_size = 4.
```python
for val_data in val_bar:
    val_images, val_labels = val_data
    outputs = net(val_images.to(device))
    # print('outputs', outputs)
    predict_y = torch.max(outputs, dim=1)[1]
    # print('predict_y', predict_y)
    acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

val_accurate = acc / val_num
print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
      (epoch + 1, running_loss / train_steps, val_accurate))
```
The printed outputs look like:

```
outputs tensor([[ 1.1575,  0.4821, -1.0027, -0.3505, -0.5067],
                [-2.1531, -3.3662,  2.7460, -1.5190,  2.3250],
                [-1.2522, -2.7431,  2.0015, -1.6097,  1.9731],
                [ 0.3428,  0.4972, -0.9707,  0.4508, -0.5267]], device='cuda:0')
```
Intuitively, each row represents one image's scores for the 5 classes; the larger the value, the more likely the image belongs to that class. Since batch_size = 4, four images are validated at a time, so the tensor has four rows. That gives:
predict_y (the result of torch.max(outputs, dim=1)[1]):

```
predict_y tensor([0, 2, 2, 1], device='cuda:0')
```
torch.max(outputs, dim=1)[1] returns the index of the maximum value in each row (dim=1 means the maximum is taken along each row; the [1] selects the index tensor rather than the values).
For example, in the first row of outputs, [ 1.1575, 0.4821, -1.0027, -0.3505, -0.5067], 1.1575 is the largest value, so the returned index is 0, which means this image's predicted index (label) is 0.
torch.eq(predict_y, val_labels.to(device)) then compares the predictions against the true labels:
- if an image's predict_y is 0 and its true label val_labels is also 0, the prediction is correct and acc increases by 1;
- if predict_y is 0 but the true label is 1, the prediction is wrong and acc increases by 0.
In the end acc holds the number of correct predictions; dividing it by the total number of validation images (val_num) gives the accuracy, i.e. exactly how many images were classified correctly.
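A small sketch of this bookkeeping using the predictions above (the ground-truth labels below are made up for illustration):

```python
import torch

predict_y  = torch.tensor([0, 2, 2, 1])   # predicted indices from torch.max(..., dim=1)[1]
val_labels = torch.tensor([0, 2, 4, 1])   # hypothetical true labels for this batch

correct = torch.eq(predict_y, val_labels)  # tensor([ True,  True, False,  True])
acc = correct.sum().item()                 # 3 correct predictions in this batch
val_accurate = acc / len(val_labels)       # 0.75 for this single batch
print(acc, val_accurate)
```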
```python
save_path = './AlexNet.pth'
if val_accurate > best_acc:
    best_acc = val_accurate
    torch.save(net.state_dict(), save_path)
```
```python
# validate
net.eval()
acc = 0.0  # accumulate accurate number / epoch
with torch.no_grad():
    val_bar = tqdm(validate_loader, file=sys.stdout)
    for val_data in val_bar:
        val_images, val_labels = val_data
        outputs = net(val_images.to(device))
        # print('outputs', outputs)
        predict_y = torch.max(outputs, dim=1)[1]
        # print('predict_y', predict_y)
        acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

val_accurate = acc / val_num
print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
      (epoch + 1, running_loss / train_steps, val_accurate))

if val_accurate > best_acc:
    best_acc = val_accurate
    torch.save(net.state_dict(), save_path)
```
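For completeness, a hedged sketch of how the best weights saved this way would typically be loaded back later (the path and num_classes mirror the training script above; adjust them to your setup):

```python
import torch

# rebuild the same architecture, then load the best checkpoint saved above
model = AlexNet(num_classes=5)
model.load_state_dict(torch.load('./AlexNet.pth', map_location='cpu'))
model.eval()  # switch to evaluation mode before running inference
```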