当前位置:   article > 正文

VGG16训练CIFAR10_PyTorch vgg16训练cifar10代码

PyTorch vgg16训练cifar10代码（代码基于 torch/torchvision，并非 TensorFlow）

模仿的b站小土堆的方法

import torch
import torchvision.datasets

#下载训练集和测试集
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models

# Build the CIFAR-10 train/test datasets and wrap each in a DataLoader.
# Images are converted to [0,1] float tensors; data is downloaded into
# ./data on the first run.
_to_tensor = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.CIFAR10(
    "./data", train=True, transform=_to_tensor, download=True)
test_data = torchvision.datasets.CIFAR10(
    "./data", train=False, transform=_to_tensor, download=True)

tn_data_loader = DataLoader(train_data, batch_size=64)
tt_data_loader = DataLoader(test_data, batch_size=64)

# Report the dataset sizes.
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练集长度为:{}\n测试集长度为:{}".format(train_data_size, test_data_size))

#修改vgg16网络模型,使其符合我们的训练集
class VGG16_NET(nn.Module):
    """VGG16 backbone (ImageNet-pretrained) with a fresh classifier head
    producing 10 logits for the CIFAR-10 classes.
    """

    def __init__(self):
        super(VGG16_NET, self).__init__()
        # Load the pretrained VGG16 and strip its original classifier; the
        # remaining layers serve as a fixed-architecture feature extractor.
        # (Attribute name `futures` kept for checkpoint compatibility.)
        backbone = models.vgg16(pretrained=True)
        backbone.classifier = nn.Sequential()
        self.futures = backbone
        # Replacement head: 25088-dim flattened features -> 10 class logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(25088, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        # Feature extraction followed by classification; Flatten inside the
        # head replaces the manual view() the original author commented out.
        feats = self.futures(x)
        return self.classifier(feats)

# Instantiate the network; move model and loss to the GPU when one exists.
vgg16 = VGG16_NET()
if torch.cuda.is_available():
    vgg16 = vgg16.cuda()

# Cross-entropy over the 10 class logits, optimized with plain SGD.
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()
learning_rate = 0.01
optimizer = torch.optim.SGD(vgg16.parameters(), lr=learning_rate)

# TensorBoard writer plus the epoch count and global step counters.
writer = SummaryWriter("./logs")

epoch = 10

train_step = 0
test_step = 0
# Train for `epoch` epochs; after each epoch, evaluate on the test set,
# log metrics to TensorBoard, and checkpoint the whole model.
for i in range(epoch):
    print("-----第{}轮训练开始-----".format(i+1))

    # BUG FIX: restore training mode each epoch. The eval() call at the end
    # of the previous epoch otherwise leaves dropout disabled for every
    # epoch after the first.
    vgg16.train()
    for data in tn_data_loader:
        imgs, targets = data
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            targets = targets.cuda()

        output = vgg16(imgs)
        optimizer.zero_grad()
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()
        train_step = train_step + 1
        if train_step % 100 == 0:
            # .item() yields a plain float, detached from the autograd graph.
            print("训练次数:{} Loss:{}".format(train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), train_step)

    # Evaluation pass: eval mode disables dropout; no_grad skips gradients.
    test_loss = 0
    total_accuracy = 0
    vgg16.eval()
    with torch.no_grad():
        for data in tt_data_loader:
            imgs, targets = data
            if torch.cuda.is_available():
                imgs = imgs.cuda()
                targets = targets.cuda()

            output = vgg16(imgs)
            loss = loss_fn(output, targets)
            # Accumulate as python floats rather than device tensors.
            test_loss = test_loss + loss.item()
            accuracy = (output.argmax(1) == targets).sum().item()
            total_accuracy = total_accuracy + accuracy

    print("整体测试集上的Loss:{}\n整体测试集的准确率:{}".format(test_loss, total_accuracy / test_data_size))
    writer.add_scalar("Test_Loss", test_loss, test_step)
    # NOTE(review): this logs the raw correct-prediction count, not the
    # accuracy rate; divide by test_data_size if a rate is wanted.
    writer.add_scalar("accuracy", total_accuracy, test_step)
    # Saves the full model object (not just state_dict), matching the
    # original script's checkpoint format.
    torch.save(vgg16, "vgg16_{}.pth".format(i+1))
    test_step = test_step + 1
    print("模型已保存")

writer.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107

最后准确率在百分之87左右

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/957659
推荐阅读
相关标签
  

闽ICP备14008679号