当前位置:   article > 正文

【Pytorch代码学习】——训练部分_train_bar

train_bar

代码基于花卉识别

训练train

代码中部分函数讲解

1.tqdm
就是展示进度条的函数
在这里插入图片描述
2.net.train()、net.eval()
目的验证中不dropout

完整训练代码

net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

测试test

测试中准确率的计算

1. 首先解读下 torch.max(input, dim) 函数

output = torch.max(input, dim)

输入

  • input是softmax函数输出的一个tensor
  • dim是max函数索引的维度0/10是每列的最大值,1是每行的最大值

输出

  • 函数会返回两个tensor,第一个tensor是每行的最大值,softmax的输出中最大的是1,所以第一个tensor是全1的tensor;第二个tensor是每行最大值的索引。

2. 准确率的计算
代码基于花卉识别5分类 {0:‘daisy’, 1:‘dandelion’, 2:‘roses’, 3:‘sunflower’, 4:‘tulips’},batchsize=4

            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                #print('outputs',outputs)
                predict_y = torch.max(outputs, dim=1)[1]
                #print('predict_y',predict_y)
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

output

utputs tensor([[ 1.1575,  0.4821, -1.0027, -0.3505, -0.5067],
       [-2.1531, -3.3662,  2.7460, -1.5190,  2.3250],
       [-1.2522, -2.7431,  2.0015, -1.6097,  1.9731],
       [ 0.3428,  0.4972, -0.9707,  0.4508, -0.5267]], device='cuda:0')
  • 1
  • 2
  • 3
  • 4

简单理解,每行代表一张照片对应5种分类的可能性,数值越大说明该照片是此类别的可能性越高。因为batchsize=4,所以一次验证了4张照片,所以tensor有四行,也就是:

  • 5分类对应5列
  • batchsize=4对应4行

predict_y(由torch.max(outputs, dim=1)[1]输出的结果)

predict_y tensor([0, 2, 2, 1], device='cuda:0')
  • 1

torch.max(outputs, dim=1)[1]输出为每行(其中dim=1代表行)最大值的索引(其中[1]代表输出索引)
如ouput中第一行[ 1.1575, 0.4821, -1.0027, -0.3505, -0.5067]中1.1575最大,所以输出索引0
也就表示,这张图片索引(标签)是0
通过函数torch.eq(predict_y, val_labels.to(device))与实际标签做对比
如图片predict_y是0,真实标签val_labels也是0,所以验证正确acc++1
如图片predict_y是0,真实标签val_labels是1,所以验证错误acc++0
最终acc中就存储了正确的个数,再除以所有图片个数就得到了正确率
这时候就可以得到到底有多少张验证正确了

保存模型

save_path = './AlexNet.pth'
  • 1
if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
  • 1
  • 2
  • 3

完整测试代码

# validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                #print('outputs',outputs)
                predict_y = torch.max(outputs, dim=1)[1]
                #print('predict_y',predict_y)
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/92628
推荐阅读
相关标签
  

闽ICP备14008679号