The model is implemented with the PyTorch framework. It contains 13 convolutional layers, 2 pooling layers, and 15 fully connected layers. Why stack so many layers? Just for fun.
The FashionMNIST dataset contains 60,000 training images and 10,000 test images; each image is single-channel and 28×28 pixels.
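To see why the first fully connected layer in the code below takes 128 * 7 * 7 input features: every 3×3 convolution uses stride 1 and padding 1, so it preserves the spatial size, while each of the two 2×2 max pools halves it, 28 → 14 → 7. A minimal sketch of that arithmetic (the helper conv_out is ours, for illustration only):

def conv_out(n, k=3, s=1, p=1):
    # output size of a convolution: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1

assert conv_out(28) == 28   # 3x3 conv, stride 1, padding 1: size unchanged
assert 28 // 2 == 14        # first 2x2 max pool
assert 14 // 2 == 7         # second 2x2 max pool
print(128 * 7 * 7)          # 6272 input features for the first nn.Linear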
import argparse

import torch
import torch.nn as nn  # alias torch.nn as nn
import torch.optim as optim
import torchvision  # data-loading utilities and common dataset interfaces
import torchvision.transforms as transforms


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # out: 14x14
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # out: 7x7
        )
        self.linear = nn.Sequential(
            nn.Linear(128 * 7 * 7, 128),  # first fully connected layer
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        )

    def forward(self, inputs):
        out = self.conv(inputs)
        out = out.view(out.size(0), -1)  # flatten to (batch, 128*7*7)
        logits = self.linear(out)
        return logits


def train(model, device, train_loader, criterion, optimizer):
    model.train()
    total_loss, total_num = 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # zero gradients once per batch (the original called this twice; once suffices)
        logits = model(data)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_num += 1
    return total_loss / total_num


def test(model, device, test_loader):
    model.eval()
    t_correct, t_sum = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits = model(data)
            pred = logits.argmax(dim=1)
            t_correct += torch.eq(pred, target).float().sum().item()
            t_sum += data.size(0)
    acc = t_correct / t_sum
    return acc


def main():
    fashionmnist_train = torchvision.datasets.FashionMNIST(
        root='F:/dataset/minist',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    fashionmnist_test = torchvision.datasets.FashionMNIST(
        root='F:/dataset/minist',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    )
    print('Train/test set sizes:', len(fashionmnist_train), len(fashionmnist_test))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_iter = torch.utils.data.DataLoader(fashionmnist_train, batch_size=args.batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(fashionmnist_test, batch_size=args.test_batch_size)
    model = CNN().to(device)  # move the model to the GPU if one is available
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.001)  # model.parameters(): the network's learnable parameters
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)
    for epoch in range(args.epochs):
        train_loss = train(model, device, train_iter, criterion, optimizer)
        acc = test(model, device, test_iter)
        print(f'epoch {epoch:2d}: loss = {train_loss:.6f}; acc={acc:.4f}; lr:{lr_sch.get_last_lr()[0]:.8f}')
        lr_sch.step()  # one decay step: lr = lr * decay rate
    if args.save_model:
        torch.save(model.state_dict(), "cnn.pt")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                        help='input batch size for testing (default: 256)')
    parser.add_argument('--epochs', type=int, default=20, metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='for saving the current model')
    # parse arguments (args is a module-level global, so main() can see it)
    args = parser.parse_args()
    main()
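Before training, it is worth sanity-checking the forward pass on a dummy batch (a minimal sketch, assuming the CNN class above is defined in the same session):

import torch

model = CNN()
model.eval()  # eval mode so BatchNorm works with a single-sample batch
with torch.no_grad():
    dummy = torch.randn(1, 1, 28, 28)  # one single-channel 28x28 image
    print(model(dummy).shape)          # torch.Size([1, 10]): one logit per class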
Training output:
Train/test set sizes: 60000 10000
epoch  0: loss = 0.680793; acc=0.8512; lr:0.00100000
epoch  1: loss = 0.413065; acc=0.8831; lr:0.00090000
epoch  2: loss = 0.375292; acc=0.8931; lr:0.00081000
epoch  3: loss = 0.352440; acc=0.8852; lr:0.00072900
epoch  4: loss = 0.342234; acc=0.8936; lr:0.00065610
epoch  5: loss = 0.316583; acc=0.8845; lr:0.00059049
epoch  6: loss = 0.304564; acc=0.8570; lr:0.00053144
epoch  7: loss = 0.282753; acc=0.9003; lr:0.00047830
epoch  8: loss = 0.259815; acc=0.9098; lr:0.00043047
epoch  9: loss = 0.237412; acc=0.9196; lr:0.00038742
epoch 10: loss = 0.218256; acc=0.9124; lr:0.00034868
epoch 11: loss = 0.205545; acc=0.9228; lr:0.00031381
epoch 12: loss = 0.187289; acc=0.9243; lr:0.00028243
epoch 13: loss = 0.173118; acc=0.9268; lr:0.00025419
epoch 14: loss = 0.156642; acc=0.9301; lr:0.00022877
epoch 15: loss = 0.140609; acc=0.9210; lr:0.00020589
epoch 16: loss = 0.128453; acc=0.9296; lr:0.00018530
epoch 17: loss = 0.116559; acc=0.9321; lr:0.00016677
epoch 18: loss = 0.104093; acc=0.9324; lr:0.00015009
epoch 19: loss = 0.089764; acc=0.9308; lr:0.00013509
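If the run was started with --save-model, the trained weights end up in cnn.pt and can be reloaded for inference roughly as follows (a minimal sketch; class_names is our own list of the standard FashionMNIST labels, not part of the original script):

import torch

model = CNN()  # same architecture as defined above
model.load_state_dict(torch.load("cnn.pt", map_location="cpu"))
model.eval()   # use BatchNorm running statistics at inference time
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
with torch.no_grad():
    img = torch.randn(1, 1, 28, 28)  # stand-in for a real preprocessed image
    print(class_names[model(img).argmax(dim=1).item()])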