赞
踩
上篇讲解了如何结合Pytorch和snntorch搭建一个脉冲神经网络,并演示了前向通道运行过程。但是一个网络不经过训练是没有意义的,脉冲神经网络的训练方法有很多种,包括无监督学习的突触可塑性、有监督学习的梯度下降法。snntorch里面采用梯度下降法对脉冲神经网络进行训练。
在之前的教程中,我们推导出LIF神经元的模型可表述为:
这实际上是一种类似于循环神经网络(RNN)的递归结构,这种结构更适用于处理序列数据。一个脉冲神经元的展开图如下图所示(注意这不是一个脉冲神经网络,只是一个神经元,只不过按时间展开了),横轴是模拟的时间,使用-Uthr代表复位机制,β代表连接权重,U[t]代表输入,S[t]代表输出。传统的 RNN 将 β 作为可学习参数, SNN 默认情况下将其视为超参数,使用超参数搜索取代了梯度消失和梯度爆炸问题。未来的教程将介绍如何将 β 作为可学习参数。
对于输入U[t],输出S[t],复位机制-Uthr,则有
其中 Θ(⋅) 是阶跃函数: 当U[t]-Uthr大于阈值时,S[t]产生脉冲,否则静默。
我们已经得到脉冲神经元的输入U[t]和输出S[t]的关系是一个阶跃函数,此时S和U的一阶导数则为0和无穷大(在脉冲上升时刻),如下图所示。梯度下降法是利用损失相对于权重的梯度来训练网络,从而更新权重,使损失最小化。S和U的导数(梯度)为一个脉冲函数,这种情况下,要不权重不更新,要不权重就直接饱和,无法进行学习。这就是所谓的死神经元问题。
解决 "死神经元 "问题的最常见方法是在前向传递过程中保持阶跃函数的原样,但在后向传递时,将S与U的导数换成过程中不会扼杀学习过程的S`和U导数项 。这听起来可能有些奇怪,但事实证明,神经网络对这种近似是相当稳健的。这就是通常所说的代梯度法。snnTorch 的默认方法(截至 v0.6.0)是使用反正切函数平滑阶跃函数。使用的导数为
之前的导数等式只计算了一个时间步的梯度,但通过时间反向传播(BPTT)算法会计算从损失L到所有时间步长t的梯度,并将它们相加。权重 W应用于每个时间步长t,因此可以想象,损失L也是在每个时间步长计算的。权重对当前损失和历史损失的影响必须相加,以确定总体梯度:
举例来说,W[t-1] 对损失的影响可以写成:
对于单个神经元来说,时间反向传播算法如下图,这里省略了复位机制,在 snnTorch 中,复位包含在前向传递中,但从后向传递中分离出来。
在传统的非脉冲神经网络中,有监督的多类分类问题会选择激活度最高的神经元,并将其作为预测类别。在脉冲神经网络中,有几种解释输出脉冲的方法。最常见的方法有:
- 速率编码:将发射率(或尖峰计数)最高的神经元作为预测类别
- 延迟编码:将最先触发的神经元作为预测类别
这可能与关于神经编码的教程 1 有相似之处。不同之处在于,在这里我们是解释(解码)输出脉冲,而不是将原始输入数据编码/转换成脉冲。
我们主要使用速率编码。当输入数据传递给网络时,我们希望正确的神经元类在模拟运行过程中发出最多的脉冲。这相当于最高的平均发射频率。实现这一目标的方法之一是将正确类神经元的膜电位增加到 U>Uthr ,而将错误类神经元的膜电位调整到 U<Uthr 。
这可以通过提取输出神经元膜电位的软最大值来实现,其中C是输出类的数量:
其实际效果是,正确类别的膜电位被鼓励增加,而不正确类别的膜电位则被降低。实际上,这意味着在所有时间步骤中,正确类别都会被鼓励发射,而不正确类别则会在所有步骤中被抑制。这可能不是最有效的 SNN 实现方法,但却是最简单的方法之一。 该损失适用于模拟的每个时间步长,然后在模拟结束时将这些损失相加:
本文借助snntorch库搭建三层脉冲神经网络,并在mnist数据集上进行训练。以下是每个包的作用:
snntorch 和其子模块:
snntorch
: snntorch 是一个用于构建和训练脉冲神经网络(SNN)的库。它提供了一系列工具和模块,包括神经元模型、脉冲生成器、可视化工具等。snntorch.spikeplot
: 这是 snntorch 提供的用于绘制脉冲数据的模块。snntorch.spikegen
: 包含用于生成脉冲信号的工具。torch 和其子模块:
torch
: PyTorch 是一个用于深度学习的开源机器学习库,提供了张量计算、自动微分等功能。torch.nn
: PyTorch 的神经网络模块,用于构建神经网络层和模型。torch.utils.data
: 提供用于处理数据加载和处理的工具,包括DataLoader
类。datasets 和 transforms:
torchvision.datasets
: 包含用于加载常见视觉数据集的工具。torchvision.transforms
: 提供对图像进行转换和预处理的工具,例如调整大小、转换为灰度、标准化等。matplotlib 和 numpy:
matplotlib.pyplot
: 用于绘制图表和可视化的库。numpy
: 提供对多维数组进行高效操作的库。PyTorch 提供模型架构,snntorch 提供脉冲神经模型,两者结合实现脉冲神经网络的训练和推断。
- import snntorch as snn
- from snntorch import spikeplot as splt
- from snntorch import spikegen
-
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- from torchvision import datasets, transforms
-
- import matplotlib.pyplot as plt
- import numpy as np
- import itertools
首先加载mnist数据集,初次运行会下载该数据集,并保存到同目录下的data文件夹中(data文件夹是自动创建的),若要修改保存路径,修改data_path变量值即可。
之后创建一个transform,对输入数据进行预处理。
transform
被定义为一个包含多个转换的Compose
对象,这些转换将应用于输入的图像数据。具体来说,这里的转换包括:
transforms.Resize((28, 28))
: 将图像大小调整为 28x28 像素。transforms.Grayscale()
: 将图像转换为灰度图。transforms.ToTensor()
: 将图像转换为 PyTorch 张量。transforms.Normalize((0,), (1,))
: 对图像进行标准化,将像素值从 [0, 1] 缩放到均值为 0、标准差为 1。这些转换的目的是将输入的手写数字图像转换为网络训练所需的格式。例如,将图像大小调整为统一的大小,将图像转为灰度以简化处理,将图像转换为张量以在 PyTorch 中进行处理,最后进行标准化以提高训练的稳定性。
- # dataloader的参数
- batch_size = 128
- data_path='/data/mnist'
-
- dtype = torch.float
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
- print(torch.cuda.is_available())
-
- # 定义transform
- transform = transforms.Compose([
- transforms.Resize((28, 28)),
- transforms.Grayscale(),
- transforms.ToTensor(),
- transforms.Normalize((0,), (1,))])
-
- mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
- mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
只有在输入参数
x
被明确传入net
后,才会调用forward()
函数中的代码。
fc1
-全连接层:对来自 MNIST 数据集的所有输入像素进行线性变换;
lif1-脉冲神经层:
在一段时间内对加权输入进行积分,如果满足阈值条件,则发出脉冲;
fc2
-全连接层:对lif1
的输出脉冲进行线性变换;
lif2-脉冲神经层:
对加权脉冲进行时间积分。
- # 创建训练和测试用的DataLoaders
- train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
- test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
-
- # 网络结构
- num_inputs = 28*28
- num_hidden = 1000
- num_outputs = 10
-
- # 时间参数
- num_steps = 25
- beta = 0.95
-
- # 定义网络结构
- class Net(nn.Module):
- def __init__(self):
- super().__init__()
-
- # Initialize layers
- self.fc1 = nn.Linear(num_inputs, num_hidden)
- self.lif1 = snn.Leaky(beta=beta)
- self.fc2 = nn.Linear(num_hidden, num_outputs)
- self.lif2 = snn.Leaky(beta=beta)
-
- def forward(self, x):
-
- # Initialize hidden states at t=0
- mem1 = self.lif1.init_leaky()
- mem2 = self.lif2.init_leaky()
-
- # 记录输出层的脉冲
- spk2_rec = []
- mem2_rec = []
-
- for step in range(num_steps):
- cur1 = self.fc1(x)
- spk1, mem1 = self.lif1(cur1, mem1)
- cur2 = self.fc2(spk1)
- spk2, mem2 = self.lif2(cur2, mem2)
- spk2_rec.append(spk2)
- mem2_rec.append(mem2)
-
- return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
-
- # Load the network onto CUDA if available
- net = Net().to(device)
下面是一个函数,它获取一批数据,对每个神经元的所有尖峰进行计数(即模拟时间内的速率代码),并将最高计数的索引与实际目标进行比较。如果两者匹配,则说明网络正确预测了目标。
- # pass data into the network, sum the spikes over time
- # and compare the neuron with the highest number of spikes
- # with the target
-
- def print_batch_accuracy(data, targets, train=False):
- output, _ = net(data.view(batch_size, -1))
- _, idx = output.sum(dim=0).max(1)
- acc = np.mean((targets == idx).detach().cpu().numpy())
-
- if train:
- print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
- else:
- print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")
-
- def train_printer():
- print(f"Epoch {epoch}, Iteration {iter_counter}")
- print(f"Train Set Loss: {loss_hist[counter]:.2f}")
- print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
- print_batch_accuracy(data, targets, train=True)
- print_batch_accuracy(test_data, test_targets, train=False)
- print("\n")
使用脉冲神经网络(SNN)进行前向和反向传播。
- '''初始化网络和优化器:网络结构被定义为 Net 类,并加载到设备(CPU 或 GPU)上。
- 优化器选择了 Adam 优化器,学习率为 lr=5e-4,动量参数为 (0.9, 0.999)。
- '''
- loss = nn.CrossEntropyLoss()
- optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))
-
- num_epochs = 1
- import snntorch as snn
- import torch
- import torch.nn as nn
- from torch.utils.data import DataLoader
- from torchvision import datasets, transforms
- import matplotlib.pyplot as plt
- import numpy as np
-
- # ... (之前的代码)
-
- # 记录训练和测试损失的列表
- loss_hist = []
- test_loss_hist = []
-
- # 记录训练和测试准确率的列表
- train_accuracy_hist = []
- test_accuracy_hist = []
-
- # 外部训练循环
- for epoch in range(num_epochs):
- iter_counter = 0
- train_batch = iter(train_loader)
-
- # 内部训练循环
- for data, targets in train_batch:
- data = data.to(device)
- targets = targets.to(device)
-
- # 前向传播
- net.train()
- spk_rec, mem_rec = net(data.view(batch_size, -1))
-
- # 初始化损失并进行时间步长的累加
- loss_val = torch.zeros((1), dtype=dtype, device=device)
- for step in range(num_steps):
- loss_val += loss(mem_rec[step], targets)
-
- # 梯度计算 + 权重更新
- optimizer.zero_grad()
- loss_val.backward()
- optimizer.step()
-
- # 记录训练损失
- loss_hist.append(loss_val.item())
-
- # 在每个迭代的末尾,使用测试集计算测试损失和准确率
- if iter_counter % 50 == 0:
- with torch.no_grad():
- net.eval()
- test_data, test_targets = next(iter(test_loader))
- test_data = test_data.to(device)
- test_targets = test_targets.to(device)
-
- # 测试集前向传播
- test_spk, test_mem = net(test_data.view(batch_size, -1))
-
- # 计算测试损失
- test_loss = torch.zeros((1), dtype=dtype, device=device)
- for step in range(num_steps):
- test_loss += loss(test_mem[step], test_targets)
- test_loss_hist.append(test_loss.item())
-
- # 计算测试准确率
- _, idx = test_spk.sum(dim=0).max(1)
- test_acc = np.mean((test_targets == idx).detach().cpu().numpy())
- test_accuracy_hist.append(test_acc)
-
- # 计算并记录训练准确率
- output, _ = net(data.view(batch_size, -1))
- _, idx = output.sum(dim=0).max(1)
- train_acc = np.mean((targets == idx).detach().cpu().numpy())
- train_accuracy_hist.append(train_acc)
-
- # 打印训练信息
- print(f"Epoch {epoch}, Iteration {iter_counter}")
- print(f"Train Set Loss: {loss_hist[-1]:.2f}, Accuracy: {train_acc*100:.2f}%")
- print(f"Test Set Loss: {test_loss_hist[-1]:.2f}, Accuracy: {test_acc*100:.2f}%")
- print("\n")
-
- iter_counter += 1
-
- # 绘制训练和测试准确率曲线
- fig = plt.figure(facecolor="w", figsize=(10, 5))
- plt.plot(train_accuracy_hist, label="Train Accuracy")
- plt.plot(test_accuracy_hist, label="Test Accuracy")
- plt.title("Accuracy Curves")
- plt.legend()
- plt.xlabel("Iteration")
- plt.ylabel("Accuracy"
- total = 0
- correct = 0
-
- # drop_last switched to False to keep all samples
- test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=False)
-
- with torch.no_grad():
- net.eval()
- for data, targets in test_loader:
- data = data.to(device)
- targets = targets.to(device)
-
- # forward pass
- test_spk, _ = net(data.view(data.size(0), -1))
-
- # calculate total accuracy
- _, predicted = test_spk.sum(dim=0).max(1)
- total += targets.size(0)
- correct += (predicted == targets).sum().item()
- print(f"Total correctly classified test set images: {correct}/{total}")
- print(f"Test Set Accuracy: {100 * correct / total:.2f}%")
输出结果: Total correctly classified test set images: 9432/10000 Test Set Accuracy: 94.32%
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。