
[Spiking Neural Network Tutorial 03] Building and Training Spiking Neural Networks: Theory


This article walks through building a spiking neural network with the snntorch library, explains how the forward model works, and runs a code simulation (the complete, directly runnable code is at the end of the article). Training a spiking neural network with gradient descent is covered in Part 4.

1. The LIF Neuron Model

For a detailed explanation of the LIF (leaky integrate-and-fire) neuron model, see the official documentation: Tutorial 2 - The Leaky Integrate-and-Fire Neuron — snntorch 0.7.0 documentation.
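As a quick refresher, the passive membrane in that tutorial is an RC circuit, $\tau \frac{dU}{dt} = -U(t) + I_{\rm in}(t)R$, integrated with the forward Euler method. A minimal sketch of one Euler step (the parameter values here are illustrative, not taken from the tutorial):

import torch

# illustrative RC membrane parameters (not from the tutorial)
tau = 5e-3        # membrane time constant (s)
R = 5.0           # membrane resistance (Ohm)
delta_t = 1e-3    # simulation time step (s)

U = torch.tensor(0.0)      # membrane potential
I_in = torch.tensor(0.1)   # constant input current

for _ in range(10):
    # forward Euler: U <- U + (dt/tau) * (-U + I_in * R)
    U = U + (delta_t / tau) * (-U + I_in * R)
print(f"membrane potential after 10 steps: {U:.3f}")  # climbs toward I_in * R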

2. A Simplified LIF Neuron Model

The full LIF neuron model is fairly involved: it requires tuning a series of hyperparameters, which means a lot of state to keep track of, and this only gets more unwieldy when scaled up to a full SNN. So let us make some simplifications.

In the previous tutorial, we used the forward Euler method to derive the following solution of the passive membrane model:

$$U(t+\Delta t) = \left(1 - \frac{\Delta t}{\tau}\right) U(t) + \frac{\Delta t}{\tau} I_{\rm in}(t)\,R$$

Assume the decay rate of the membrane potential is

$$\beta = 1 - \frac{\Delta t}{\tau}$$

If we take $t$ to index discrete time steps in a sequence rather than continuous time, we can set $\Delta t = 1$. To further reduce the number of hyperparameters, assume $R = 1$, which gives

$$U[t+1] = \beta U[t] + (1-\beta) I_{\rm in}[t+1]$$

In deep learning, the weighting factor applied to an input is usually a learnable parameter. We therefore absorb the fixed coefficient $(1-\beta)$ into a learnable weight $W$ and replace $I_{\rm in}[t]$ with an input $X[t]$, i.e. $W X[t] = I_{\rm in}[t]$. This finally yields:

$$U[t+1] = \beta U[t] + W X[t+1]$$

Now consider the membrane reset mechanism: when the membrane potential exceeds the threshold, the neuron emits an output spike, and the membrane potential must then be reset. With the reset-by-subtraction mechanism, the model becomes:

$$S_{\rm out}[t] = \begin{cases} 1, & \text{if } U[t] > U_{\rm thr} \\ 0, & \text{otherwise} \end{cases}$$

$$U[t+1] = \beta U[t] + W X[t+1] - S_{\rm out}[t]\, U_{\rm thr}$$

$W$ is a learnable parameter, and $U_{\rm thr}$ is usually simply set to 1 (though it can be tuned), so the decay rate $\beta$ is the only hyperparameter that has to be specified.
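As a concrete example, take $\beta = 0.9$, $W = 0.5$, $U_{\rm thr} = 1$, a constant input $X = 1$, and $U[0] = 0$: the membrane climbs as $U[1] = 0.5$, $U[2] = 0.9 \cdot 0.5 + 0.5 = 0.95$, $U[3] = 0.9 \cdot 0.95 + 0.5 = 1.355$. Since $U[3] > 1$, the neuron spikes, and the subtraction term kicks in on the next step: $U[4] = 0.9 \cdot 1.355 + 0.5 - 1 = 0.7195$.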

Feeding the neuron a step input, its response can be implemented in code as follows:

# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
    spk = (mem > threshold)  # if membrane exceeds threshold, spk=1, else 0
    mem = beta * mem + w*x - spk*threshold
    return spk, mem

#@title Plotting Settings
def plot_cur_mem_spk(cur, mem, spk, thr_line=False, vline=False, title=False, ylim_max1=1.25, ylim_max2=1.25):
    # Generate plots
    fig, ax = plt.subplots(3, figsize=(8,6), sharex=True,
                           gridspec_kw={'height_ratios': [1, 1, 0.4]})
    # Plot input current
    ax[0].plot(cur, c="tab:orange")
    ax[0].set_ylim([0, ylim_max1])
    ax[0].set_xlim([0, 200])
    ax[0].set_ylabel("Input Current ($I_{in}$)")
    if title:
        ax[0].set_title(title)
    # Plot membrane potential
    ax[1].plot(mem)
    ax[1].set_ylim([0, ylim_max2])
    ax[1].set_ylabel("Membrane Potential ($U_{mem}$)")
    if thr_line:
        ax[1].axhline(y=thr_line, alpha=0.25, linestyle="dashed", c="black", linewidth=2)
    plt.xlabel("Time step")
    # Plot output spikes using spikeplot
    splt.raster(spk, ax[2], s=400, c="black", marker="|")
    if vline:
        ax[2].axvline(x=vline, ymin=0, ymax=6.75, alpha=0.15, linestyle="dashed", c="black", linewidth=2, zorder=0, clip_on=False)
    plt.ylabel("Output spikes")
    plt.yticks([])
    plt.show()

# set neuronal parameters
delta_t = torch.tensor(1e-3)
tau = torch.tensor(5e-3)
beta = torch.exp(-delta_t/tau)
print(f"The decay rate is: {beta:.3f}")

num_steps = 200

# initialize inputs/outputs + small step current input
x = torch.cat((torch.zeros(10), torch.ones(190)*0.5), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = []
spk_rec = []

# neuron parameters
w = 0.4
beta = 0.819

# neuron simulation
for step in range(num_steps):
    spk, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)
    mem_rec.append(mem)
    spk_rec.append(spk)

# convert lists to tensors
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

plot_cur_mem_spk(x*w, mem_rec, spk_rec, thr_line=1, ylim_max1=0.5,
                 title="LIF Neuron Model With Weighted Step Voltage")
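A quick sanity check on the decay rate: the code sets $\beta = \exp(-\Delta t/\tau) = \exp(-0.2) \approx 0.819$. This exponential form is the exact decay factor of the homogeneous membrane equation, and for a small step ratio it agrees closely with the Euler factor $1 - \Delta t/\tau = 0.8$ used in the derivation above.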

3. Implementing the LIF Neuron with snntorch

In snntorch, the model above is implemented by instantiating snn.Leaky, e.g.:

lif1 = snn.Leaky(beta=0.8)
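Besides beta, snn.Leaky also exposes the firing threshold and the reset behaviour; the defaults match the model derived above (threshold of 1, reset by subtraction). Written out explicitly, the equivalent of the call above is:

# equivalent to the defaults, shown explicitly for clarity
lif1 = snn.Leaky(beta=0.8, threshold=1.0, reset_mechanism="subtract")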

The neuron model is now stored in lif1. Its inputs and outputs are as follows:

Inputs

cur_in: each element of the input sequence is passed in, one time step at a time

mem: the membrane potential from the previous step, U[t], is also passed in as an input

Outputs

spk_out: the output spike

mem: the membrane potential at the current time step
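A single call therefore looks like this (a minimal sketch; the input value 0.5 is just illustrative):

mem = torch.zeros(1)                      # initial membrane potential
spk, mem = lif1(torch.tensor(0.5), mem)   # one time step with input current 0.5
print(spk, mem)                           # spike flag and updated membrane potential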

Using the same step input, the response of the snntorch neuron model is implemented with the following code:

# define the neuron
lif1 = snn.Leaky(beta=0.8)

# Small step current input
w = 0.21
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*w), 0)
mem = torch.zeros(1)
spk = torch.zeros(1)
mem_rec = []
spk_rec = []

# neuron simulation
for step in range(num_steps):
    spk, mem = lif1(cur_in[step], mem)
    mem_rec.append(mem)
    spk_rec.append(spk)

# convert lists to tensors
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=1, ylim_max1=0.5,
                 title="snn.Leaky Neuron Model")

4. A Feedforward Spiking Neural Network

So far we have only looked at how a single neuron responds to an input stimulus. snnTorch makes it straightforward to extend this to a deep neural network. In this section we build a 3-layer fully connected network with dimensions 784-1000-10.

PyTorch forms the connections between neurons, while snnTorch creates the neurons themselves. First, initialize all the layers.

# layer parameters
num_inputs = 784
num_hidden = 1000
num_outputs = 10
beta = 0.99

# initialize layers
fc1 = nn.Linear(num_inputs, num_hidden)
lif1 = snn.Leaky(beta=beta)
fc2 = nn.Linear(num_hidden, num_outputs)
lif2 = snn.Leaky(beta=beta)

Next, initialize the hidden variables and outputs of each spiking neuron. This becomes increasingly tedious as the network grows, and the static method init_leaky() takes care of it: on the first forward pass, the shape of the hidden state is initialized automatically to match the dimensions of the input data.

# Initialize hidden states
mem1 = lif1.init_leaky()
mem2 = lif2.init_leaky()

# record outputs
mem2_rec = []
spk1_rec = []
spk2_rec = []

Now create the input spike train and pass it to the network. We simulate 784 input neurons for 200 time steps, so the raw input has dimensions 200×784. However, neural networks usually process data in minibatches, so unsqueeze(1) adds a batch dimension of size 1:

spk_in = spikegen.rate_conv(torch.rand((200, 784))).unsqueeze(1)
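spikegen.rate_conv treats each input element as a firing probability and draws spikes from it. Roughly speaking (this is a sketch of the idea, not the library's exact implementation), it behaves like an element-wise Bernoulli trial:

raw = torch.rand((200, 784))        # values in [0, 1], used as firing probabilities
spk_approx = torch.bernoulli(raw)   # rough equivalent: spike with probability equal to the input
spk_in = spikegen.rate_conv(raw).unsqueeze(1)
print(spk_in.shape)                 # torch.Size([200, 1, 784]): [time x batch x features]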

The full forward pass proceeds as follows:

  • The weights W are initialized, and the 200×784 spike train spk_in is fed to the input layer one time step at a time, starting with the first 784 spikes

  • This produces the input current term WX[t] from the equation above, which contributes to the spiking neuron's membrane potential U[t+1]

  • If U[t+1] exceeds the threshold, a spike is fired from that neuron

  • The spike is weighted by the second layer's weights, and the process then repeats for all inputs, weights, and neurons

  • If there is no spike, the postsynaptic neuron receives nothing

# layer parameters
num_inputs = 784
num_hidden = 1000
num_outputs = 10
beta = 0.99

# initialize layers
fc1 = nn.Linear(num_inputs, num_hidden)
lif1 = snn.Leaky(beta=beta)
fc2 = nn.Linear(num_hidden, num_outputs)
lif2 = snn.Leaky(beta=beta)

# Initialize hidden states
mem1 = lif1.init_leaky()
mem2 = lif2.init_leaky()

# record outputs
mem2_rec = []
spk1_rec = []
spk2_rec = []

spk_in = spikegen.rate_conv(torch.rand((200, 784))).unsqueeze(1)

# network simulation
for step in range(num_steps):
    cur1 = fc1(spk_in[step])       # post-synaptic current <-- spk_in x weight
    spk1, mem1 = lif1(cur1, mem1)  # mem[t+1] <-- post-syn current + decayed membrane
    cur2 = fc2(spk1)
    spk2, mem2 = lif2(cur2, mem2)
    mem2_rec.append(mem2)
    spk1_rec.append(spk1)
    spk2_rec.append(spk2)

# convert lists to tensors
mem2_rec = torch.stack(mem2_rec)
spk1_rec = torch.stack(spk1_rec)
spk2_rec = torch.stack(spk2_rec)

def plot_snn_spikes(spk_in, spk1_rec, spk2_rec, title):
    # Generate plots
    fig, ax = plt.subplots(3, figsize=(8,7), sharex=True,
                           gridspec_kw={'height_ratios': [1, 1, 0.4]})
    # Plot input spikes
    splt.raster(spk_in[:,0], ax[0], s=0.03, c="black")
    ax[0].set_ylabel("Input Spikes")
    ax[0].set_title(title)
    # Plot hidden layer spikes
    splt.raster(spk1_rec.reshape(num_steps, -1), ax[1], s=0.05, c="black")
    ax[1].set_ylabel("Hidden Layer")
    # Plot output spikes
    splt.raster(spk2_rec.reshape(num_steps, -1), ax[2], c="black", marker="|")
    ax[2].set_ylabel("Output Spikes")
    ax[2].set_ylim([0, 10])
    plt.show()

plot_snn_spikes(spk_in, spk1_rec, spk2_rec, "Fully Connected Spiking Neural Network")

The plots show, from top to bottom, the input layer, the hidden layer, and the output layer.
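In practice, the same network is usually packaged as a torch.nn.Module so it can be trained like any other PyTorch model. A minimal sketch of that structure (the class name Net is just illustrative):

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):
        # reset/initialize hidden states at the start of each forward pass
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec, mem2_rec = [], []
        for step in range(x.size(0)):  # iterate over time steps
            cur1 = self.fc1(x[step])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)
        return torch.stack(spk2_rec), torch.stack(mem2_rec)

net = Net()
spk2_rec, mem2_rec = net(spk_in)  # same simulation as the loop above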

Complete Code

# imports
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def leaky_integrate_and_fire(mem, x, w, beta, threshold=1):
    spk = (mem > threshold)  # if membrane exceeds threshold, spk=1, else 0
    mem = beta * mem + w*x - spk*threshold
    return spk, mem

#@title Plotting Settings
def plot_cur_mem_spk(cur, mem, spk, thr_line=False, vline=False, title=False, ylim_max1=1.25, ylim_max2=1.25):
    # Generate plots
    fig, ax = plt.subplots(3, figsize=(8,6), sharex=True,
                           gridspec_kw={'height_ratios': [1, 1, 0.4]})
    # Plot input current
    ax[0].plot(cur, c="tab:orange")
    ax[0].set_ylim([0, ylim_max1])
    ax[0].set_xlim([0, 200])
    ax[0].set_ylabel("Input Current ($I_{in}$)")
    if title:
        ax[0].set_title(title)
    # Plot membrane potential
    ax[1].plot(mem)
    ax[1].set_ylim([0, ylim_max2])
    ax[1].set_ylabel("Membrane Potential ($U_{mem}$)")
    if thr_line:
        ax[1].axhline(y=thr_line, alpha=0.25, linestyle="dashed", c="black", linewidth=2)
    plt.xlabel("Time step")
    # Plot output spikes using spikeplot
    splt.raster(spk, ax[2], s=400, c="black", marker="|")
    if vline:
        ax[2].axvline(x=vline, ymin=0, ymax=6.75, alpha=0.15, linestyle="dashed", c="black", linewidth=2, zorder=0, clip_on=False)
    plt.ylabel("Output spikes")
    plt.yticks([])
    plt.show()

def plot_snn_spikes(spk_in, spk1_rec, spk2_rec, title):
    # Generate plots
    fig, ax = plt.subplots(3, figsize=(8,7), sharex=True,
                           gridspec_kw={'height_ratios': [1, 1, 0.4]})
    # Plot input spikes
    splt.raster(spk_in[:,0], ax[0], s=0.03, c="black")
    ax[0].set_ylabel("Input Spikes")
    ax[0].set_title(title)
    # Plot hidden layer spikes
    splt.raster(spk1_rec.reshape(num_steps, -1), ax[1], s=0.05, c="black")
    ax[1].set_ylabel("Hidden Layer")
    # Plot output spikes
    splt.raster(spk2_rec.reshape(num_steps, -1), ax[2], c="black", marker="|")
    ax[2].set_ylabel("Output Spikes")
    ax[2].set_ylim([0, 10])
    plt.show()

# set neuronal parameters
delta_t = torch.tensor(1e-3)
tau = torch.tensor(5e-3)
beta = torch.exp(-delta_t/tau)
print(f"The decay rate is: {beta:.3f}")

num_steps = 200

# initialize inputs/outputs + small step current input
x = torch.cat((torch.zeros(10), torch.ones(190)*0.5), 0)
mem = torch.zeros(1)
spk_out = torch.zeros(1)
mem_rec = []
spk_rec = []

# neuron parameters
w = 0.4
beta = 0.819

# neuron simulation
for step in range(num_steps):
    spk, mem = leaky_integrate_and_fire(mem, x[step], w=w, beta=beta)
    mem_rec.append(mem)
    spk_rec.append(spk)

# convert lists to tensors
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

plot_cur_mem_spk(x*w, mem_rec, spk_rec, thr_line=1, ylim_max1=0.5,
                 title="LIF Neuron Model With Weighted Step Voltage")

lif1 = snn.Leaky(beta=0.8)

# Small step current input
w = 0.21
cur_in = torch.cat((torch.zeros(10), torch.ones(190)*w), 0)
mem = torch.zeros(1)
spk = torch.zeros(1)
mem_rec = []
spk_rec = []

# neuron simulation
for step in range(num_steps):
    spk, mem = lif1(cur_in[step], mem)
    mem_rec.append(mem)
    spk_rec.append(spk)

# convert lists to tensors
mem_rec = torch.stack(mem_rec)
spk_rec = torch.stack(spk_rec)

plot_cur_mem_spk(cur_in, mem_rec, spk_rec, thr_line=1, ylim_max1=0.5,
                 title="snn.Leaky Neuron Model")

# layer parameters
num_inputs = 784
num_hidden = 1000
num_outputs = 10
beta = 0.99

# initialize layers
fc1 = nn.Linear(num_inputs, num_hidden)
lif1 = snn.Leaky(beta=beta)
fc2 = nn.Linear(num_hidden, num_outputs)
lif2 = snn.Leaky(beta=beta)

# Initialize hidden states
mem1 = lif1.init_leaky()
mem2 = lif2.init_leaky()

# record outputs
mem2_rec = []
spk1_rec = []
spk2_rec = []

spk_in = spikegen.rate_conv(torch.rand((200, 784))).unsqueeze(1)

# network simulation
for step in range(num_steps):
    cur1 = fc1(spk_in[step])       # post-synaptic current <-- spk_in x weight
    spk1, mem1 = lif1(cur1, mem1)  # mem[t+1] <-- post-syn current + decayed membrane
    cur2 = fc2(spk1)
    spk2, mem2 = lif2(cur2, mem2)
    mem2_rec.append(mem2)
    spk1_rec.append(spk1)
    spk2_rec.append(spk2)

# convert lists to tensors
mem2_rec = torch.stack(mem2_rec)
spk1_rec = torch.stack(spk1_rec)
spk2_rec = torch.stack(spk2_rec)

plot_snn_spikes(spk_in, spk1_rec, spk2_rec, "Fully Connected Spiking Neural Network")
