赞
踩
实现多层感知器（Multilayer Perceptron，MLP）同样遵循以下步骤：构建数据、初始化参数、定义激活函数、构建模型、定义损失函数、训练模型。
方法一:从零开始实现
import torch
import torch.nn as nn
import numpy as np

import d2lzh_pytorch as d2l

# Layer widths: 28 * 28 input features -> 256 hidden units -> 10 outputs.
num_i, num_h, num_o = 28 * 28, 256, 10

# Data: train/test iterators from the d2l helper `load_data_fashion_mnist`.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)


def _gaussian_weight(shape):
    """Return a float32 leaf tensor of N(0, 0.01^2) samples that tracks gradients.

    Samples come from NumPy's global RNG, then are copied into a tensor.
    """
    samples = np.random.normal(0, 0.01, shape)
    return torch.tensor(samples, dtype=torch.float32, requires_grad=True)


# Parameters: Gaussian weight matrices, zero-initialised bias vectors.
# w1 is drawn before w2, so the NumPy RNG stream is consumed in a fixed order.
w1 = _gaussian_weight((num_i, num_h))
b1 = torch.zeros(num_h, requires_grad=True)
w2 = _gaussian_weight((num_h, num_o))
b2 = torch.zeros(num_o, requires_grad=True)
params = [w1, b1, w2, b2]
# Activation function
def relu(x):
    """Element-wise rectified linear unit: max(x, 0), built from plain tensor ops."""
    zero = torch.tensor(0.0)
    return torch.max(x, zero)
# Model: one ReLU hidden layer followed by a linear output layer.
def net(x):
    """Forward pass of the hand-written MLP.

    Flattens ``x`` into rows of ``num_i`` features, computes the hidden
    activations ``relu(x @ w1 + b1)``, and returns ``h @ w2 + b2``
    (no softmax is applied here).

    ``reshape`` is used instead of ``view``. ``view`` raises a RuntimeError
    for non-contiguous inputs (e.g. transposed or sliced batches), while
    ``reshape`` returns a view when possible and copies only when it must.
    """
    x = x.reshape(-1, num_i)
    h = relu(x.mm(w1) + b1)
    return h.mm(w2) + b2
# Loss: PyTorch's combined log-softmax + negative log-likelihood on net's raw outputs.
loss = nn.CrossEntropyLoss()

# Training: hand the manual parameter list and learning rate to d2l's loop.
# NOTE(review): lr = 100.0 looks far too large for SGD. It is presumably tuned
# for d2l's own `sgd` update, which (in d2lzh_pytorch) is believed to divide
# the gradient by batch_size even though CrossEntropyLoss already averages
# over the batch. That would make the effective step about 100 / 256.
# Confirm against d2l.sgd before changing this value.
num_epochs, lr = 5, 100.0
d2l.train_ch3(
    net,
    train_iter,
    test_iter,
    loss,
    num_epochs,
    batch_size,
    params,
    lr,
)
data:image/s3,"s3://crabby-images/deb9d/deb9d52e6c78f73fbfaadc6e519fd00d286664e1" alt=""
方法二：尽量直接调用 PyTorch 现成的模块（能调包就不自己实现）
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim

import d2lzh_pytorch as d2l

# Layer widths of the MLP.
num_i = 28 * 28  # input features after flattening
num_h = 256      # hidden units
num_o = 10       # outputs

# Data: train/test iterators from d2l's Fashion-MNIST loader.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# Network: Flatten -> Linear(n_i, n_h) -> ReLU -> Linear(n_h, n_o).
class MLP(nn.Module):
    """Two-layer perceptron built entirely from stock ``torch.nn`` modules.

    Args:
        n_i: number of input features after flattening (28 * 28 in this script).
        n_h: number of hidden units.
        n_o: number of outputs.
    """

    def __init__(self, n_i, n_h, n_o):
        super().__init__()
        # nn.Flatten() keeps dim 0 (the batch) and merges all remaining dims.
        # It replaces the project helper d2l.FlattenLayer, which presumably
        # reshapes to (batch, -1) as well (confirm). Unlike a bare .view(),
        # it also accepts non-contiguous inputs. It has no parameters, so
        # state_dict keys are unchanged.
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(n_i, n_h)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(n_h, n_o)

    def forward(self, input):
        """Return the output-layer scores for a batch ``input`` (batch first)."""
        # `input` shadows the builtin, but the name is kept for keyword callers.
        x = self.flatten(input)
        x = self.relu(self.linear1(x))
        return self.linear2(x)
net = MLP(num_i, num_h, num_o)

# Initialisation: weight matrices ~ N(0, 0.01^2), biases = 0.
# This matches the from-scratch version of the same model; drawing the biases
# from the Gaussian too would make the two implementations start differently.
for name, param in net.named_parameters():
    if name.endswith("bias"):
        init.zeros_(param)
    else:
        init.normal_(param, mean=0, std=0.01)

# Loss: softmax cross-entropy on the network's raw outputs.
loss = nn.CrossEntropyLoss()

# Optimizer: plain SGD over every parameter of the MLP.
optimizer = optim.SGD(net.parameters(), lr=0.5)

# Training: pass the optimizer to d2l's generic training loop.
num_epochs = 5
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, optimizer=optimizer)
data:image/s3,"s3://crabby-images/deb9d/deb9d52e6c78f73fbfaadc6e519fd00d286664e1" alt=""
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。