
Deep Learning Primer: Training and Testing on the FashionMNIST Dataset (a 30-Layer Neural Network)


This post uses the PyTorch framework. The model contains 13 convolutional layers, 2 pooling layers, and 15 fully connected layers, 30 layers in total. Why stack this many layers? Just for fun.
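As a sanity check, the layer counts can be verified programmatically. A minimal sketch, assuming the CNN class and the imports from the listing below are already in scope:

model = CNN()
n_conv = sum(isinstance(m, nn.Conv2d) for m in model.modules())
n_pool = sum(isinstance(m, nn.MaxPool2d) for m in model.modules())
n_fc = sum(isinstance(m, nn.Linear) for m in model.modules())
print(n_conv, n_pool, n_fc)  # prints: 13 2 15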

The FashionMNIST dataset contains 60,000 training images and 10,000 test images. Each image is single-channel (grayscale) with size 28×28, so FashionMNIST is a 1-channel dataset.
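You can check the channel count and image size directly. A minimal sketch, reusing the dataset root from the listing below:

import torchvision
import torchvision.transforms as transforms

ds = torchvision.datasets.FashionMNIST(
    root='F:/dataset/minist', train=True, download=True,
    transform=transforms.ToTensor(),
)
img, label = ds[0]
print(img.shape)  # torch.Size([1, 28, 28]): 1 channel, 28x28 pixels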

import argparse
import torch
import torch.nn as nn  # alias torch.nn as nn
import torch.optim as optim
import torchvision  # data-loading helpers and common dataset interfaces
import torchvision.transforms as transforms


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 28x28 -> 14x14

            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 14x14 -> 7x7
        )
        self.linear = nn.Sequential(
            nn.Linear(128 * 7 * 7, 128),  # fully connected layers
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.BatchNorm1d(num_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        )

    def forward(self, inputs):
        out = self.conv(inputs)
        out = out.view(out.size(0), -1)  # flatten (N, 128, 7, 7) -> (N, 6272)
        logits = self.linear(out)
        return logits


def train(model, device, train_loader, criterion, optimizer):
    model.train()
    total_loss, total_num = 0, 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, target)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_num += 1

    return total_loss / total_num


def test(model, device, test_loader):
    model.eval()
    t_correct, t_sum = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            logits = model(data)
            pred = logits.argmax(dim=1)
            t_correct += torch.eq(pred, target).float().sum().item()
            t_sum += data.size(0)
        acc = t_correct / t_sum
        return acc


def main():
    fashionmnist_train = torchvision.datasets.FashionMNIST(
        root='F:/dataset/minist', train=True, download=True,
        transform=transforms.ToTensor(),
    )
    fashionmnist_test = torchvision.datasets.FashionMNIST(
        root='F:/dataset/minist', train=False, download=True,
        transform=transforms.ToTensor(),
    )
    print('Training set and test set sizes:', len(fashionmnist_train), len(fashionmnist_test))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_iter = torch.utils.data.DataLoader(fashionmnist_train, batch_size=args.batch_size, shuffle=True)
    test_iter = torch.utils.data.DataLoader(fashionmnist_test, batch_size=args.test_batch_size)

    model = CNN().to(device)  # move the model to the GPU if one is available
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.001)  # model.parameters(): the network's learnable parameters
    lr_sch = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9)

    for epoch in range(args.epochs):
        train_loss = train(model, device, train_iter, criterion, optimizer)
        acc = test(model, device, test_iter)
        print(f'epoch {epoch:2d}: loss = {train_loss:.6f}; acc={acc:.4f}; lr:{lr_sch.get_last_lr()[0]:.8f}')
        lr_sch.step()  # apply one learning-rate decay step: lr = lr * gamma

    if args.save_model:
        torch.save(model.state_dict(), "cnn.pt")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FashionMNIST CNN training')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                        help='input batch size for testing (default: 256)')
    parser.add_argument('--epochs', type=int, default=20, metavar='N',
                        help='number of epochs to train (default: 20)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='save the trained model to cnn.pt')

    # parse command-line arguments
    args = parser.parse_args()
    main()
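To train with the defaults and save the weights, run the script from the command line (the script's file name is not given here; fmnist_cnn.py is assumed):

python fmnist_cnn.py --epochs 20 --save-model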


Output:

Training set and test set sizes: 60000 10000
epoch  0: loss = 0.680793; acc=0.8512; lr:0.00100000
epoch  1: loss = 0.413065; acc=0.8831; lr:0.00090000
epoch  2: loss = 0.375292; acc=0.8931; lr:0.00081000
epoch  3: loss = 0.352440; acc=0.8852; lr:0.00072900
epoch  4: loss = 0.342234; acc=0.8936; lr:0.00065610
epoch  5: loss = 0.316583; acc=0.8845; lr:0.00059049
epoch  6: loss = 0.304564; acc=0.8570; lr:0.00053144
epoch  7: loss = 0.282753; acc=0.9003; lr:0.00047830
epoch  8: loss = 0.259815; acc=0.9098; lr:0.00043047
epoch  9: loss = 0.237412; acc=0.9196; lr:0.00038742
epoch 10: loss = 0.218256; acc=0.9124; lr:0.00034868
epoch 11: loss = 0.205545; acc=0.9228; lr:0.00031381
epoch 12: loss = 0.187289; acc=0.9243; lr:0.00028243
epoch 13: loss = 0.173118; acc=0.9268; lr:0.00025419
epoch 14: loss = 0.156642; acc=0.9301; lr:0.00022877
epoch 15: loss = 0.140609; acc=0.9210; lr:0.00020589
epoch 16: loss = 0.128453; acc=0.9296; lr:0.00018530
epoch 17: loss = 0.116559; acc=0.9321; lr:0.00016677
epoch 18: loss = 0.104093; acc=0.9324; lr:0.00015009
epoch 19: loss = 0.089764; acc=0.9308; lr:0.00013509
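If the script was run with --save-model, the weights are written to cnn.pt and can be reloaded for inference. A minimal sketch, assuming the CNN class from the listing above is in scope:

model = CNN()
model.load_state_dict(torch.load('cnn.pt', map_location='cpu'))
model.eval()  # put the BatchNorm layers into inference mode before predicting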