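This article only covers how to train a spiking neural network. The underlying principles of spiking neural networks, the derivation of the neuron model, and how snntorch implements training are left for Section 3.
This article uses the snntorch library to build a three-layer spiking neural network and trains it on the MNIST dataset. The role of each imported package is as follows:
snntorch and its submodules:
- snntorch: a library for building and training spiking neural networks (SNNs). It provides neuron models, spike generators, visualization tools, and more.
- snntorch.spikeplot: the snntorch module for plotting spike data.
- snntorch.spikegen: tools for generating spike trains.
torch and its submodules:
- torch: PyTorch, an open-source machine learning library for deep learning that provides tensor computation, automatic differentiation, and more.
- torch.nn: PyTorch's neural network module, used to build layers and models.
- torch.utils.data: utilities for loading and handling data, including the DataLoader class.
datasets and transforms:
- torchvision.datasets: loaders for common vision datasets.
- torchvision.transforms: image transformation and preprocessing tools, such as resizing, grayscale conversion, and normalization.
matplotlib and numpy:
- matplotlib.pyplot: a library for plotting and visualization.
- numpy: a library for efficient operations on multidimensional arrays.
PyTorch provides the model architecture and snntorch provides the spiking neuron models; together they are used to train the spiking neural network and run inference with it.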

    import snntorch as snn
    from snntorch import spikeplot as splt
    from snntorch import spikegen

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

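First, load the MNIST dataset. On the first run the dataset is downloaded and saved under the directory given by data_path, which is created automatically. To change the location, edit data_path. Note that the value used below, '/data/mnist', is an absolute path; use a relative path such as './data/mnist' to keep the data in a data folder next to the script.
Next, create a transform to preprocess the input data. transform is defined as a Compose object containing several transformations that are applied, in order, to each input image:
- transforms.Resize((28, 28)): resizes the image to 28x28 pixels.
- transforms.Grayscale(): converts the image to grayscale.
- transforms.ToTensor(): converts the image to a PyTorch tensor with values in [0, 1].
- transforms.Normalize((0,), (1,)): subtracts a mean of 0 and divides by a standard deviation of 1, which leaves the pixel values in [0, 1] unchanged. With these arguments it does not rescale the data to zero mean and unit variance.
These transformations put the handwritten-digit images into the format the network expects: a uniform size, a single grayscale channel, and a tensor representation that PyTorch can process. MNIST images are already 28x28 grayscale, so Resize and Grayscale act only as safeguards here.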

    # DataLoader parameters
    batch_size = 128
    data_path = '/data/mnist'

    dtype = torch.float
    # Prefer CUDA, then Apple MPS, and fall back to the CPU
    device = (
        torch.device("cuda") if torch.cuda.is_available()
        else torch.device("mps") if torch.backends.mps.is_available()
        else torch.device("cpu")
    )
    print(torch.cuda.is_available())

    # Define the transform
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ])

    mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

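To make the effect of the normalization step concrete, here is a small framework-free sketch of the per-pixel arithmetic, (x - mean) / std. The helper name normalize and the sample pixel values are made up for illustration, and the mean/std pair 0.1307/0.3081 is an assumed alternative (the commonly quoted MNIST statistics), not something the code in this article uses.

```python
def normalize(pixel, mean, std):
    """Per-pixel arithmetic of a mean/std normalization: (x - mean) / std."""
    return (pixel - mean) / std


# Hypothetical pixel values as they look after ToTensor(), i.e. in [0, 1].
pixels = [0.0, 0.25, 0.5, 1.0]

# mean=0, std=1 (the setting used in this article) leaves the values unchanged.
identity = [normalize(p, 0.0, 1.0) for p in pixels]

# Assumed alternative: commonly quoted MNIST mean and standard deviation.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
standardized = [normalize(p, MNIST_MEAN, MNIST_STD) for p in pixels]

print(identity)
print([round(v, 3) for v in standardized])
```

With mean 0 and std 1 the network simply sees the raw [0, 1] intensities, which is a reasonable input range when the pixel values are used directly as input current.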
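The network is structured like a multilayer perceptron (MLP): an input layer, a hidden layer, and an output layer.
Input layer:
- The input size is set by num_inputs = 28 * 28, the number of pixels per image. Each MNIST image is a 28x28 grayscale image.
Hidden layer:
- The hidden layer is a linear layer (nn.Linear) with num_hidden = 1000 neurons, each connected to every input pixel.
- Each hidden unit feeds a Leaky spiking neuron (snn.Leaky). Its decay rate is controlled by beta, set here to 0.95. In the absence of input, the membrane potential of a Leaky neuron decays gradually, by a factor of beta per time step; the neuron emits a spike when its membrane potential exceeds the threshold.
Output layer:
- The output layer is also a linear layer, with num_outputs = 10 neurons, one for each of the 10 digit classes in MNIST. The goal of the network is to classify an input image as one of the digits 0 to 9.
- Each output unit likewise feeds a Leaky spiking neuron.
Forward pass: the input image passes through the first linear layer, and the result drives the hidden Leaky neurons. The spikes they emit pass through the second linear layer to the output layer, whose Leaky neurons produce the final output spikes. This process is repeated for num_steps = 25 time steps, and at every step the spikes and membrane potentials of the output layer are recorded. Note that no spike encoding is applied to the input in this code: the same static image is fed in at every time step, and the temporal dynamics come from how the membrane potentials evolve over time.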
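The leak-and-fire behaviour described above can be sketched for a single neuron in plain Python. This is a simplified illustration, not snntorch's implementation: the function names lif_step and simulate are hypothetical, and the threshold of 1.0 and the reset-by-subtraction rule are assumptions made for the sketch. Only beta = 0.95 and the 25 time steps come from the article.

```python
def lif_step(current, mem, beta=0.95, threshold=1.0):
    """One time step of a simplified leaky integrate-and-fire neuron."""
    # If the previous potential was above threshold, subtract the threshold
    # from the new potential (assumed reset-by-subtraction rule).
    reset = 1.0 if mem > threshold else 0.0
    # Leak: the old potential decays by a factor of beta, then input is added.
    mem = beta * mem + current - reset * threshold
    # Fire when the updated potential exceeds the threshold.
    spk = 1.0 if mem > threshold else 0.0
    return spk, mem


def simulate(current, num_steps=25):
    """Drive one neuron with a constant input current for num_steps steps."""
    mem = 0.0
    spikes, mems = [], []
    for _ in range(num_steps):
        spk, mem = lif_step(current, mem)
        spikes.append(spk)
        mems.append(mem)
    return spikes, mems


# A constant input, like the static image presented at every time step.
spikes, mems = simulate(current=0.3)
print("spike train:", [int(s) for s in spikes])

# With no input, the potential only leaks: 0.5 becomes 0.95 * 0.5.
_, decayed = lif_step(current=0.0, mem=0.5)
print("decayed potential:", decayed)
```

Even though the input never changes, the neuron needs several steps to charge up before its first spike, which is why the network is run for multiple time steps per image.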

    # Create the training and test DataLoaders
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

    # Network architecture
    num_inputs = 28*28
    num_hidden = 1000
    num_outputs = 10

    # Temporal dynamics
    num_steps = 25
    beta = 0.95

    # Define the network
    class Net(nn.Module):
        def __init__(self):
            super().__init__()

            # Initialize layers
            self.fc1 = nn.Linear(num_inputs, num_hidden)
            self.lif1 = snn.Leaky(beta=beta)
            self.fc2 = nn.Linear(num_hidden, num_outputs)
            self.lif2 = snn.Leaky(beta=beta)

        def forward(self, x):
            # Initialize hidden states at t=0
            mem1 = self.lif1.init_leaky()
            mem2 = self.lif2.init_leaky()

            # Record the output layer's spikes and membrane potentials
            spk2_rec = []
            mem2_rec = []

            # The same input x is presented at every time step
            for step in range(num_steps):
                cur1 = self.fc1(x)
                spk1, mem1 = self.lif1(cur1, mem1)
                cur2 = self.fc2(spk1)
                spk2, mem2 = self.lif2(cur2, mem2)
                spk2_rec.append(spk2)
                mem2_rec.append(mem2)

            # Stack along a new leading time dimension
            return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

    # Load the network onto the selected device (CUDA if available)
    net = Net().to(device)

Next, run forward and backward propagation to train the spiking neural network (SNN).
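The training code applies a cross-entropy loss to the output layer's membrane potential at every time step and sums the results. The following plain-Python sketch shows that idea for a single sample. The helper names cross_entropy and summed_loss and the membrane potential values are hypothetical; the article's code uses nn.CrossEntropyLoss, which by default also averages over the batch.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


def summed_loss(mem_rec, target):
    """Sum the per-step loss over all recorded time steps."""
    return sum(cross_entropy(mem_t, target) for mem_t in mem_rec)


# Hypothetical membrane potentials of 3 output neurons over 4 time steps.
mem_rec = [
    [0.1, 0.2, 0.0],
    [0.2, 0.5, 0.1],
    [0.1, 0.9, 0.2],
    [0.0, 1.2, 0.1],
]

# Neuron 1 has the highest potential throughout, so class 1 gives a lower loss.
loss_correct = summed_loss(mem_rec, target=1)
loss_wrong = summed_loss(mem_rec, target=2)
print(f"target=1: {loss_correct:.3f}, target=2: {loss_wrong:.3f}")
```

Summing over time steps pushes the correct neuron's membrane potential to stay high at every step, rather than only at the end of the simulation.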

    # Loss and optimizer. The network (class Net) has already been created and
    # moved to the selected device (CPU or GPU). The optimizer is Adam with
    # learning rate lr=5e-4 and betas=(0.9, 0.999).
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

    num_epochs = 1

    # Lists for recording training and test loss
    loss_hist = []
    test_loss_hist = []

    # Lists for recording training and test accuracy
    train_accuracy_hist = []
    test_accuracy_hist = []

    # Outer training loop
    for epoch in range(num_epochs):
        iter_counter = 0
        train_batch = iter(train_loader)

        # Inner training loop (one minibatch per iteration)
        for data, targets in train_batch:
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            net.train()
            spk_rec, mem_rec = net(data.view(batch_size, -1))

            # Initialize the loss and accumulate it over the time steps
            loss_val = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                loss_val += loss(mem_rec[step], targets)

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Record the training loss
            loss_hist.append(loss_val.item())

            # Every 50 iterations, compute loss and accuracy on one test batch
            if iter_counter % 50 == 0:
                with torch.no_grad():
                    net.eval()
                    test_data, test_targets = next(iter(test_loader))
                    test_data = test_data.to(device)
                    test_targets = test_targets.to(device)

                    # Test set forward pass
                    test_spk, test_mem = net(test_data.view(batch_size, -1))

                    # Test loss
                    test_loss = torch.zeros((1), dtype=dtype, device=device)
                    for step in range(num_steps):
                        test_loss += loss(test_mem[step], test_targets)
                    test_loss_hist.append(test_loss.item())

                    # Test accuracy: the neuron with the most spikes is the prediction
                    _, idx = test_spk.sum(dim=0).max(1)
                    test_acc = np.mean((test_targets == idx).detach().cpu().numpy())
                    test_accuracy_hist.append(test_acc)

                    # Training accuracy on the current minibatch
                    output, _ = net(data.view(batch_size, -1))
                    _, idx = output.sum(dim=0).max(1)
                    train_acc = np.mean((targets == idx).detach().cpu().numpy())
                    train_accuracy_hist.append(train_acc)

                # Print training progress
                print(f"Epoch {epoch}, Iteration {iter_counter}")
                print(f"Train Set Loss: {loss_hist[-1]:.2f}, Accuracy: {train_acc*100:.2f}%")
                print(f"Test Set Loss: {test_loss_hist[-1]:.2f}, Accuracy: {test_acc*100:.2f}%")
                print("\n")

            iter_counter += 1

    # Plot the training and test accuracy curves
    # (one point is recorded every 50 iterations)
    fig = plt.figure(facecolor="w", figsize=(10, 5))
    plt.plot(train_accuracy_hist, label="Train Accuracy")
    plt.plot(test_accuracy_hist, label="Test Accuracy")
    plt.title("Accuracy Curves")
    plt.legend()
    plt.xlabel("Iteration")
    plt.ylabel("Accuracy")
    plt.show()

Finally, evaluate the trained network on the full test set:

    total = 0
    correct = 0

    # drop_last switched to False to keep all samples
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

    with torch.no_grad():
        net.eval()
        for data, targets in test_loader:
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass; data.size(0) also handles a smaller final batch
            test_spk, _ = net(data.view(data.size(0), -1))

            # Accumulate the number of correct predictions
            _, predicted = test_spk.sum(dim=0).max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    print(f"Total correctly classified test set images: {correct}/{total}")
    print(f"Test Set Accuracy: {100 * correct / total:.2f}%")

Output:
Total correctly classified test set images: 9432/10000
Test Set Accuracy: 94.32%
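The prediction rule used in the evaluation, test_spk.sum(dim=0).max(1), counts each output neuron's spikes over all time steps and picks the most active neuron. A minimal plain-Python sketch of that rule follows; the helper name predict_from_spikes, the spike recordings, and the targets are made up for illustration, and breaking ties in favour of the lowest index is a property of this sketch only.

```python
def predict_from_spikes(spk_rec):
    """spk_rec[t][n] is 1 if output neuron n spiked at time step t, else 0."""
    num_neurons = len(spk_rec[0])
    # Sum over the time dimension: one spike count per output neuron.
    counts = [sum(step[n] for step in spk_rec) for n in range(num_neurons)]
    # The most active neuron is the predicted class (ties: lowest index).
    return counts.index(max(counts)), counts


# Hypothetical recordings: 5 time steps, 3 output neurons, two samples.
sample_a = [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]]
sample_b = [[0, 0, 1], [0, 0, 1], [0, 0, 0], [1, 0, 1], [0, 0, 1]]
targets = [1, 0]

predictions = [predict_from_spikes(s)[0] for s in (sample_a, sample_b)]
correct = sum(p == t for p, t in zip(predictions, targets))
accuracy = 100 * correct / len(targets)
print(predictions, f"{accuracy:.2f}%")
```

This is a rate-based decoding: the class is read from how often each output neuron fires, not from when it fires.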
The complete code is listed below.

    # 1. Import libraries and modules
    import snntorch as snn
    from snntorch import spikeplot as splt
    from snntorch import spikegen

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

    # 2. Data loading and preprocessing
    # DataLoader parameters
    batch_size = 128
    data_path = '/data/mnist'

    dtype = torch.float
    # Prefer CUDA, then Apple MPS, and fall back to the CPU
    device = (
        torch.device("cuda") if torch.cuda.is_available()
        else torch.device("mps") if torch.backends.mps.is_available()
        else torch.device("cpu")
    )
    print(torch.cuda.is_available())

    # Define the transform
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ])

    mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

    # 3. Define the network structure
    # Create DataLoaders
    train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

    # Network architecture
    num_inputs = 28*28
    num_hidden = 1000
    num_outputs = 10

    # Temporal dynamics
    num_steps = 25
    beta = 0.95

    # Define the network
    class Net(nn.Module):
        def __init__(self):
            super().__init__()

            # Initialize layers
            self.fc1 = nn.Linear(num_inputs, num_hidden)
            self.lif1 = snn.Leaky(beta=beta)
            self.fc2 = nn.Linear(num_hidden, num_outputs)
            self.lif2 = snn.Leaky(beta=beta)

        def forward(self, x):
            # Initialize hidden states at t=0
            mem1 = self.lif1.init_leaky()
            mem2 = self.lif2.init_leaky()

            # Record the final layer
            spk2_rec = []
            mem2_rec = []

            for step in range(num_steps):
                cur1 = self.fc1(x)
                spk1, mem1 = self.lif1(cur1, mem1)
                cur2 = self.fc2(spk1)
                spk2, mem2 = self.lif2(cur2, mem2)
                spk2_rec.append(spk2)
                mem2_rec.append(mem2)

            return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

    # Load the network onto the selected device (CUDA if available)
    net = Net().to(device)

    # 4. Training and periodic testing

    # Loss and optimizer. The network (class Net) has already been created and
    # moved to the selected device (CPU or GPU). The optimizer is Adam with
    # learning rate lr=5e-4 and betas=(0.9, 0.999).
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

    num_epochs = 1

    # Lists for recording training and test loss
    loss_hist = []
    test_loss_hist = []

    # Lists for recording training and test accuracy
    train_accuracy_hist = []
    test_accuracy_hist = []

    # Outer training loop
    for epoch in range(num_epochs):
        iter_counter = 0
        train_batch = iter(train_loader)

        # Inner training loop (one minibatch per iteration)
        for data, targets in train_batch:
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            net.train()
            spk_rec, mem_rec = net(data.view(batch_size, -1))

            # Initialize the loss and accumulate it over the time steps
            loss_val = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                loss_val += loss(mem_rec[step], targets)

            # Gradient calculation + weight update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            # Record the training loss
            loss_hist.append(loss_val.item())

            # Every 50 iterations, compute loss and accuracy on one test batch
            if iter_counter % 50 == 0:
                with torch.no_grad():
                    net.eval()
                    test_data, test_targets = next(iter(test_loader))
                    test_data = test_data.to(device)
                    test_targets = test_targets.to(device)

                    # Test set forward pass
                    test_spk, test_mem = net(test_data.view(batch_size, -1))

                    # Test loss
                    test_loss = torch.zeros((1), dtype=dtype, device=device)
                    for step in range(num_steps):
                        test_loss += loss(test_mem[step], test_targets)
                    test_loss_hist.append(test_loss.item())

                    # Test accuracy: the neuron with the most spikes is the prediction
                    _, idx = test_spk.sum(dim=0).max(1)
                    test_acc = np.mean((test_targets == idx).detach().cpu().numpy())
                    test_accuracy_hist.append(test_acc)

                    # Training accuracy on the current minibatch
                    output, _ = net(data.view(batch_size, -1))
                    _, idx = output.sum(dim=0).max(1)
                    train_acc = np.mean((targets == idx).detach().cpu().numpy())
                    train_accuracy_hist.append(train_acc)

                # Print training progress
                print(f"Epoch {epoch}, Iteration {iter_counter}")
                print(f"Train Set Loss: {loss_hist[-1]:.2f}, Accuracy: {train_acc*100:.2f}%")
                print(f"Test Set Loss: {test_loss_hist[-1]:.2f}, Accuracy: {test_acc*100:.2f}%")
                print("\n")

            iter_counter += 1

    # Plot the training and test accuracy curves
    # (one point is recorded every 50 iterations)
    fig = plt.figure(facecolor="w", figsize=(10, 5))
    plt.plot(train_accuracy_hist, label="Train Accuracy")
    plt.plot(test_accuracy_hist, label="Test Accuracy")
    plt.title("Accuracy Curves")
    plt.legend()
    plt.xlabel("Iteration")
    plt.ylabel("Accuracy")
    plt.show()

    # 5. Evaluate on the full test set
    total = 0
    correct = 0

    # drop_last switched to False to keep all samples
    test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)

    with torch.no_grad():
        net.eval()
        for data, targets in test_loader:
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass; data.size(0) also handles a smaller final batch
            test_spk, _ = net(data.view(data.size(0), -1))

            # Accumulate the number of correct predictions
            _, predicted = test_spk.sum(dim=0).max(1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    print(f"Total correctly classified test set images: {correct}/{total}")
    print(f"Test Set Accuracy: {100 * correct / total:.2f}%")

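One detail in the listing is worth spelling out: the loaders used during training are created with drop_last=True, while the final evaluation loader uses drop_last=False and reshapes with data.size(0) instead of batch_size. The arithmetic below is a small sketch of why. The helper num_batches is hypothetical; the dataset sizes are MNIST's 60000 training and 10000 test images.

```python
import math


def num_batches(num_samples, batch_size, drop_last):
    """Number of batches yielded for a dataset with num_samples items."""
    if drop_last:
        # An incomplete final batch is discarded.
        return num_samples // batch_size
    # An incomplete final batch is kept.
    return math.ceil(num_samples / batch_size)


batch_size = 128

# Training: every batch has exactly batch_size samples.
train_batches = num_batches(60000, batch_size, drop_last=True)
dropped_train_samples = 60000 - train_batches * batch_size

# Final evaluation: all samples are kept, so the last batch is smaller.
test_batches = num_batches(10000, batch_size, drop_last=False)
last_test_batch = 10000 - (test_batches - 1) * batch_size

print(train_batches, dropped_train_samples)
print(test_batches, last_test_batch)
```

Because the final test batch holds fewer than batch_size samples, reshaping it with data.view(batch_size, -1) would not yield one row per image, whereas data.view(data.size(0), -1) adapts to whatever batch arrives.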