当前位置:   article > 正文

PyTorch示例——MLP图像分类-手写数字_pytorch mlp

pytorch mlp

版本信息

  • PyTorch: 1.12.1
  • Python: 3.7.13

导包

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Normalize, Compose
数据集 MNIST

  • 探索一下数据
# Load the raw MNIST training split with no transform, so each sample comes back
# untouched (the output below shows a (PIL image, label) pair).
explore_data = datasets.MNIST(root="./data", train=True, download=True)
# Inspect the very first sample (index 0).
explore_data[0]
  • 输出信息
(<PIL.Image.Image image mode=L size=28x28 at 0x7F9E30058FD0>, 5)
  • 图像展示
explore_data[0][0]  # image part of sample 0 (a PIL image); displayed when it is the last expression of a notebook cell
image.png

  • 标签
explore_data[0][1]  # label part of sample 0 (the output below shows 5)
5
  • 多遍历几张图片看看(下面的 show_images 函数适用于样本为 (图像, 标签) 二元组、且图像为 PIL 图像或二维数组的 PyTorch 图像数据集;若图像已被转换为形状为 (C, H, W) 的张量,需要先去掉通道维再显示)
import matplotlib.pyplot as plt

def show_images(n_rows, n_cols, x_data):
    """Plot the first ``n_rows * n_cols`` samples of a dataset in a grid.

    Each sample is read as ``x_data[i][0]`` (the image, passed to
    ``plt.imshow``) and ``x_data[i][1]`` (the label, used as the subplot
    title), i.e. ``(image, label)`` pairs such as those returned by an
    untransformed torchvision dataset.

    Args:
        n_rows: Number of grid rows; must be positive.
        n_cols: Number of grid columns; must be positive.
        x_data: Indexable dataset supporting ``len()``.

    Raises:
        ValueError: If a grid dimension is not positive, or the grid needs
            more samples than ``x_data`` contains.
    """
    if n_rows < 1 or n_cols < 1:
        raise ValueError(
            f"n_rows and n_cols must be positive, got {n_rows} x {n_cols}")
    n_images = n_rows * n_cols
    # A grid that exactly fits the dataset is valid, so only reject "more than".
    if n_images > len(x_data):
        raise ValueError(
            f"a {n_rows} x {n_cols} grid needs {n_images} samples, "
            f"but only {len(x_data)} are available")

    plt.figure(figsize=(n_cols * 1.5, n_rows * 1.5))  # 1.5 x 1.5 inches per cell
    # Cells are filled row by row, so the flat index maps straight to the sample.
    for index in range(n_images):
        sample = x_data[index]  # fetch once: datasets may apply a transform per access
        plt.subplot(n_rows, n_cols, index + 1)  # subplot positions are 1-based
        plt.imshow(sample[0], cmap="binary", interpolation="nearest")  # image
        plt.axis("off")
        plt.title(sample[1])  # label
    plt.show()

show_images(3, 5, explore_data)  # plot the first 15 raw (untransformed) training samples in a 3 x 5 grid
手写数字图片

  • 正式处理数据
# Preprocessing applied to every sample: convert the PIL image to a tensor,
# then standardise it with the mean/std values commonly used for MNIST.
transform_funcs = Compose([ToTensor(), Normalize((0.1307, ), (0.3081, ))])

# Both splits share the same root directory and transform; only the `train`
# flag differs (True -> training split, False -> test split).
train_data, test_data = (
    datasets.MNIST(
        root="./data",
        train=split_is_train,
        download=True,
        transform=transform_funcs,
    )
    for split_is_train in (True, False)
)

# Shape of each split's underlying image storage (printed below: N x 28 x 28).
for split in (train_data, test_data):
    print(split.data.shape)
  • 输出信息
torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])
构建模型 MLP

class MLPModel(nn.Module):
    """Fully connected classifier: flatten -> Linear -> ReLU -> Linear -> ReLU -> Linear.

    With the defaults the layer sizes are 784 -> 512 -> 256 -> 10 (28 x 28
    MNIST images, 10 digit classes). The network returns raw logits; no
    softmax is applied, so pair it with a loss that expects unnormalised
    scores, such as ``nn.CrossEntropyLoss``.

    Args:
        in_features: Number of values per sample after flattening.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28, hidden1: int = 512,
                 hidden2: int = 256, num_classes: int = 10):
        super().__init__()
        # Layers are registered in the same order as before, so parameter
        # names (state_dict keys) and seeded initialisation are unchanged.
        self.flatten = nn.Flatten()  # keep the batch dim, flatten the rest
        self.linear1 = nn.Linear(in_features, hidden1)
        self.relu = nn.ReLU()  # parameter-free, so one instance serves both hidden layers
        self.linear2 = nn.Linear(hidden1, hidden2)
        self.linear3 = nn.Linear(hidden2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, num_classes)`` for the batch ``x``."""
        out = self.flatten(x)
        out = self.relu(self.linear1(out))
        out = self.relu(self.linear2(out))
        return self.linear3(out)
开始训练

# Hyper-parameters
epoch_num = 10
batch_size = 64
learning_rate = 0.0005

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data loader: reshuffle the training set every epoch.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Model, loss function and optimizer.
model = MLPModel().to(device)
# nn.CrossEntropyLoss applies (log-)softmax internally, so the model outputs raw logits.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training. loss_list gets one value per epoch: the sample-weighted mean
# training loss over the whole epoch. Recording only the last mini-batch's loss
# would be noisy, and with 60000 % 64 != 0 that batch is a smaller partial one.
loss_list = []
for epoch in range(epoch_num):
    model.train()  # ensure training mode, e.g. after an earlier evaluation pass
    epoch_loss_sum = torch.zeros((), device=device)
    seen = 0
    for i, (X_train, y_train) in enumerate(train_loader):
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        pred = model(X_train)
        batch_loss = criterion(pred, y_train)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # The criterion averages over the batch, so weight by batch size for an
        # exact per-sample epoch mean; accumulating on-device avoids a host sync per step.
        epoch_loss_sum += batch_loss.detach() * X_train.size(0)
        seen += X_train.size(0)

        if (i + 1) % 100 == 0:
            print(f"Train... [epoch {epoch + 1}/{epoch_num}, step {i + 1}/{len(train_loader)}]\t[loss {batch_loss.item()}]")

    loss_list.append(epoch_loss_sum.item() / seen)
  • 输出信息
Train... [epoch 1/10, step 100/938]	[loss 0.3897044062614441]
Train... [epoch 1/10, step 200/938]	[loss 0.4615764319896698]
Train... [epoch 1/10, step 300/938]	[loss 0.09734677523374557]
Train... [epoch 1/10, step 400/938]	[loss 0.07513687759637833]
Train... [epoch 1/10, step 500/938]	[loss 0.14482976496219635]
Train... [epoch 1/10, step 600/938]	[loss 0.13744832575321198]
Train... [epoch 1/10, step 700/938]	[loss 0.05915261059999466]
Train... [epoch 1/10, step 800/938]	[loss 0.15903039276599884]
Train... [epoch 1/10, step 900/938]	[loss 0.08145355433225632]
Train... [epoch 2/10, step 100/938]	[loss 0.09195274114608765]
Train... [epoch 2/10, step 200/938]	[loss 0.021573053672909737]
Train... [epoch 2/10, step 300/938]	[loss 0.12080539762973785]
Train... [epoch 2/10, step 400/938]	[loss 0.04143274948000908]
Train... [epoch 2/10, step 500/938]	[loss 0.06194964796304703]
Train... [epoch 2/10, step 600/938]	[loss 0.00492143863812089]
Train... [epoch 2/10, step 700/938]	[loss 0.03946655988693237]
Train... [epoch 2/10, step 800/938]	[loss 0.06333575397729874]
Train... [epoch 2/10, step 900/938]	[loss 0.1421077400445938]
......
Train... [epoch 10/10, step 100/938]	[loss 0.0019266537856310606]
Train... [epoch 10/10, step 200/938]	[loss 0.010688461363315582]
Train... [epoch 10/10, step 300/938]	[loss 0.006648594979196787]
Train... [epoch 10/10, step 400/938]	[loss 0.009121301583945751]
Train... [epoch 10/10, step 500/938]	[loss 0.0005547187756747007]
Train... [epoch 10/10, step 600/938]	[loss 0.04679783061146736]
Train... [epoch 10/10, step 700/938]	[loss 0.002314511453732848]
Train... [epoch 10/10, step 800/938]	[loss 0.02470582351088524]
Train... [epoch 10/10, step 900/938]	[loss 0.010127813555300236]
绘制训练曲线

import matplotlib.pyplot as plt

# One point per recorded epoch, numbered from 1 to match the "epoch k/N" labels
# printed during training. Deriving the x range from len(loss_list) keeps x and y
# the same length even if training was stopped early.
plt.plot(range(1, len(loss_list) + 1), loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()
image.png

测试

# Evaluate accuracy on the held-out test split. Sample order does not affect
# accuracy, so the loader is not shuffled (deterministic iteration).
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# Switch to eval mode for inference (this matters once layers such as dropout or
# batch-norm are added), and restore the previous mode afterwards, even on error.
was_training = model.training
model.eval()
try:
    with torch.no_grad():  # no gradients needed: skip building the autograd graph
        correct = 0
        total = 0
        for X_test, y_test in test_loader:
            X_test = X_test.to(device)
            y_test = y_test.to(device)
            output = model(X_test)
            pred = output.argmax(dim=1)  # predicted class = index of the largest logit
            total += y_test.size(0)  # number of samples evaluated so far
            correct += (pred == y_test).sum().item()  # number predicted correctly

        print(f'total = {total}, acurrcy = {100 * correct / total}%')
finally:
    model.train(was_training)
  • 输出信息
total = 10000, acurrcy = 97.75%
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/235332
推荐阅读
相关标签
  

闽ICP备14008679号