
MATLAB Simulation of the Hebbian Learning Training Process for Spiking Neural Networks (SNN)


Contents

I. Theoretical Basis

II. Case Background

1. Problem Description

2. Approach and Workflow

III. Partial MATLAB Simulation Code

IV. Analysis of Simulation Results

V. References


I. Theoretical Basis

        In recent years, deep learning has transformed machine learning, and computer vision in particular. In this approach, deep (multi-layer) artificial neural networks (ANNs) are trained in a supervised manner using back-propagation. Large numbers of labeled training samples are required, but the resulting classification accuracy is impressive, sometimes even surpassing humans. Neurons in an ANN are characterized by a single, static, continuous-valued activation. Biological neurons, however, compute and transmit information with discrete spikes, and in addition to the firing rate, spike timing also carries information. Spiking neural networks (SNNs) are therefore biologically more realistic than ANNs, and arguably the only viable choice if one wants to understand how the brain computes. SNNs are also more hardware-friendly and energy-efficient than ANNs, which makes them attractive for technology, especially portable devices. Training deep SNNs, however, remains a challenge: the transfer function of a spiking neuron is generally non-differentiable, which prevents back-propagation. Recent supervised and unsupervised methods for training deep SNNs have been reviewed and compared in terms of accuracy, computational cost, and hardware friendliness. The current picture is that SNNs still lag behind ANNs in accuracy, but the gap is shrinking and may even vanish on some tasks, while SNNs typically require far fewer operations.
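Spiking neurons of the kind discussed above are commonly modeled as leaky integrate-and-fire (LIF) units. Below is a minimal Python sketch (the article's code is MATLAB); the constants loosely mirror Rm = 80, rm = 30 ms, and theta = 10 mV from the listing later in this article, and the function name is ours:

```python
def lif_spike_train(input_current, dt=0.2, tau_m=30.0, R_m=80.0, v_th=10.0):
    """Leaky integrate-and-fire neuron: integrate v' = (-v + R_m*I)/tau_m
    with Euler steps of dt ms; emit a spike and reset to 0 when v >= v_th."""
    v = 0.0
    spikes = []
    for step, I in enumerate(input_current):
        v += dt * (-v + R_m * I) / tau_m   # Euler step of the membrane equation
        if v >= v_th:
            spikes.append(step * dt)       # record spike time in ms
            v = 0.0                        # reset after firing
    return spikes

# a constant drive for 100 ms produces a regular spike train;
# the spike times, not just the rate, carry the information
times = lif_spike_train([0.5] * 500)
```

Unlike an ANN unit's static activation, the output here is a list of discrete spike times, which is exactly the quantity STDP operates on.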

       Unsupervised learning in SNNs usually incorporates STDP into the learning mechanism. The most common form of biological STDP has a very intuitive interpretation. If the presynaptic neuron fires shortly before the postsynaptic neuron (within roughly 10 ms), the weight connecting them is increased. If the presynaptic neuron fires shortly after the postsynaptic neuron, the causal relationship between the two events is spurious, and the weight is weakened. The strengthening is called long-term potentiation (LTP) and the weakening long-term depression (LTD). The term "long-term" distinguishes these effects from the very transient ones, on the scale of a few milliseconds, that are also observed experimentally.

      The following formula is an idealization, obtained by fitting experimental data, of the most commonly observed STDP rule for a pair of spikes:

        Δw = A₁·exp(Δt/τ₁),    if Δt ≤ 0   (LTP)
        Δw = A₂·exp(−Δt/τ₂),   if Δt > 0   (LTD)

where Δt = t_pre − t_post is the time difference between the presynaptic and postsynaptic spikes, A₁ > 0 and A₂ < 0 set the amplitudes, and τ₁, τ₂ are the time constants.

       The first case in the formula above describes LTP, while the second describes LTD. The strength of the effect is modulated by a decaying exponential whose magnitude is controlled by the time difference between the presynaptic and postsynaptic spikes, scaled by the time constant. Artificial SNNs rarely use this exact rule; they usually use variants of it, for greater simplicity or to satisfy convenient mathematical properties.
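The pairwise rule can be sketched in a few lines of Python (the article's code is MATLAB); the amplitudes and time constants mirror A1 = 0.1, A2 = −0.105, r1 = r2 = 1 ms from the listing below:

```python
import math

def stdp_dw(t_pre, t_post, A1=0.1, A2=-0.105, tau1=1.0, tau2=1.0):
    """Weight change for one pre/post spike pair, with dt = t_pre - t_post."""
    dt = t_pre - t_post
    if dt <= 0:
        return A1 * math.exp(dt / tau1)   # pre before post: causal pairing, LTP
    return A2 * math.exp(-dt / tau2)      # pre after post: spurious pairing, LTD

# pre fires 1 ms before post -> potentiation; 1 ms after -> depression
assert stdp_dw(9.0, 10.0) > 0
assert stdp_dw(11.0, 10.0) < 0
```

Because A₂ is slightly larger in magnitude than A₁ here, uncorrelated spike pairs weaken a synapse on average, which helps keep the weights from growing without bound.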

II. Case Background

1. Problem Description

        Good learning methods for SNNs are still lacking; they grew out of traditional rate-based networks, which are trained with the back-propagation algorithm. Here an efficient Hebbian learning method with homeostasis is used for a network of spiking neurons. As in STDP, the timing between spikes drives the synaptic modification, and homeostasis ensures that the synaptic weights stay bounded so that learning remains stable. A winner-take-all mechanism is also implemented, to promote competitive learning among the output neurons. The original authors implemented this method in C++ object-oriented code (called CSpike) and tested it on Gabor-filtered images, finding bell-shaped tuning curves over a test set of 36 images with different Gabor-filter orientations. These bell-shaped curves resemble those observed experimentally in areas V1 and MT/V5 of the mammalian brain.
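The interplay of Hebbian growth, a homeostatic bound, and winner-take-all competition can be illustrated with a small Python sketch (a stand-in for the mechanism described above, not the article's implementation; the function name and the clipping bound are ours):

```python
import numpy as np

rng = np.random.default_rng(1)

def wta_hebbian_step(w, x, lr=0.3, w_min=0.0, w_max=1.0):
    """One winner-take-all Hebbian step: only the output neuron with the
    largest response updates its weights, and clipping to [w_min, w_max]
    stands in for the homeostatic bound that keeps learning stable."""
    responses = w @ x                          # each output neuron's drive
    winner = int(np.argmax(responses))         # competition: a single winner
    w[winner] += lr * x * (w_max - w[winner])  # soft-bounded Hebbian growth
    np.clip(w, w_min, w_max, out=w)
    return winner

# 4 output neurons x 15 input lines, as in the MATLAB model below
w = 0.5 + 0.5 * rng.random((4, 15))
x = rng.integers(0, 2, 15).astype(float)
winner = wta_hebbian_step(w, x)
```

Because only the winner learns, repeated presentations push different output neurons to specialize on different input patterns, which is what the character-recognition experiment below relies on.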

2. Approach and Workflow

SNNs are the latest, third generation of neural networks. The concrete simulation steps are as follows:
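As a rough sketch of the per-character training step, here is a condensed Python re-implementation under simplifying assumptions (the function name, the fixed ±1 ms timing, and the convergence test are ours, not the article's func_* routines; the constants mirror times = 200, error = 1e-3, lin = 0.3, A1, A2, r1 = r2 = 1 from the listing below):

```python
import math

def train_character(pattern, w0, n_iters=200, tol=1e-3, lr=0.3,
                    A1=0.1, A2=-0.105, tau=1.0):
    """Repeat the STDP update over a 15-bit character pattern until the
    normalized weight vector stops changing."""
    w = list(w0)
    for _ in range(n_iters):
        m = max(w)
        prev = [wi / m for wi in w]
        for n, bit in enumerate(pattern):
            # active input line: the presynaptic spike precedes the postsynaptic one
            dt = -1.0 if bit == 1 else 1.0
            dw = A1 * math.exp(dt / tau) if dt <= 0 else A2 * math.exp(-dt / tau)
            w[n] += lr * dw * w[n]          # multiplicative Hebbian update
        m = max(w)
        cur = [wi / m for wi in w]
        if math.dist(prev, cur) <= tol:     # converged: normalized weights stable
            break
    return cur

# a 3x5 'A'-like glyph flattened to 15 bits; trained weights follow the pattern
pattern = [1,1,1, 1,0,1, 1,1,1, 1,0,1, 1,0,1]
w = train_character(pattern, [0.75] * 15)
```

Weights on active input lines converge toward the maximum while weights on inactive lines decay toward zero, so the normalized weight vector ends up imprinted with the character.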

 

III. Partial MATLAB Simulation Code

The MATLAB simulation program is given below:

clc;
clear;
close all;
warning off;
addpath 'func\'
rng(1);   % fix the random seed for reproducibility

%% Display Fig.1 of the paper: the character set used
load Character\Character_set.mat
func_view_character;

%% Display Fig.2 of the paper: representation of 'A'
A_Line = Representation_Character(A_ch);
figure;
stem(1:5*3, A_Line, 'LineWidth', 2);
hold on
stairs(1:5*3, A_Line, 'r');
axis([0, 5*3+1, -0.25, 1.25]);
title('Fig.2. Representation of A');

%% Part 1: simulation with the four characters
% SNN model parameters
Rm      = 80;
theta   = 10;     % firing threshold, 10 mV
rs      = 2;      % synaptic time constant, 2 ms
rm      = 30;     % membrane time constant, 30 ms
rmin    = 2;      % 2 ms
rmax    = 30;     % 30 ms
lin     = 0.3;    % learning rate
lin_dec = 0.05;
A1      = 0.1;    % LTP amplitude
A2      = -0.105; % LTD amplitude
r1      = 1;      % LTP time constant, 1 ms
r2      = 1;      % LTD time constant, 1 ms
tstep   = 0.2;    % 0.2 ms
times   = 200;    % number of training iterations
error   = 1e-3;   % training convergence tolerance
vth     = 7;

% Train the network to recognize A, B, C, D
% (corresponds to Fig.3 of the paper: output when each character is presented individually)
% Random initial weights wij
N_in      = 15;
w_initial = 0.5 + 0.5*rand(N_in, 4);
w0        = w_initial;
dew       = zeros(1, N_in);
dT        = zeros(N_in, 4);   % spike-time differences (renamed from det, a MATLAB builtin)
tpre      = 300*ones(N_in, times);
tpost     = 301*ones(N_in, times);
Time2     = 4000;
dt        = 0.05;
STIME     = 24;

%% Convert the characters to input levels
A_Line = Representation_Character(A_ch);
B_Line = Representation_Character(B_ch);
C_Line = Representation_Character(C_ch);
D_Line = Representation_Character(D_ch);
Lines  = [A_Line B_Line C_Line D_Line];

for num = 1:4
    Dat = Lines(:, num);
    for i = 1:times
        w{i} = w_initial;
    end
    ind = 0;
    for i = 1:times
        fprintf('training iteration %d\n', i);
        ind = ind + 1;
        % Compute rd
        for n1 = 1:N_in
            rd(n1) = rmax - abs(w{i}(n1))*(rmax - rmin);
        end
        % Compute Rd
        for n1 = 1:N_in
            Rd(n1) = (rd(n1)*theta/Rm) * (rm/rd(n1))^(rm/(rm - rd(n1)));
        end
        % Compute the spike-time difference delta t
        for n1 = 1:N_in
            if Dat(n1) == 1
                dT(n1) = tpre(n1,i) - tpost(n1,i);
            else
                dT(n1) = -(tpre(n1,i) - tpost(n1,i));
            end
        end
        % Compute delta w (STDP) and update the weights
        for n1 = 1:N_in
            if dT(n1) <= 0
                dew(n1) = A1*exp( dT(n1)/r1);   % LTP
            else
                dew(n1) = A2*exp(-dT(n1)/r2);   % LTD
            end
            if i > 1
                w{i}(n1) = w{i-1}(n1) + lin*dew(n1)*w{i-1}(n1);
            end
        end
        % Compute the dendritic current Id
        for n1 = 1:N_in
            Id{n1} = func_Id(Rd(n1), rd(n1), w{i}(n1), Dat(n1), Time2, dt, tpre(n1,i));
        end
        % Compute the synaptic current Is
        Is = func_Is(N_in, rs, w{i}, Dat, Time2, dt, tpre(n1,i));
        % Compute the membrane potential u
        Um = func_um(N_in, w{i}, Rm, Id, Is, rm, Time2, dt, Dat, tpre(n1,i), vth);
        % Check the training error for convergence
        if i > 1
            err2(ind-1) = abs(norm(w{i}/max(w{i})) - norm(w{i-1}/max(w{i-1})));
            if err2(ind-1) <= error
                break;
            end
        end
    end
    Ws(:, num) = w{end}/max(w{end});
    clear Is Id Um w
end

%% Test after training: present all four characters to each trained neuron
% (one figure per trained character: A, B, C, D)
names  = 'ABCD';
colors = 'rgbc';
for k = 1:4
    UmAll = [];
    for ij = 1:STIME
        fprintf('test segment %d\n', ij);
        for j = 1:4
            % Compute Id
            for n1 = 1:N_in
                Id{n1} = func_Id(Rd(n1), rd(n1), Ws(n1,k), Lines(n1,j), Time2, dt, tpre(n1,j));
            end
            % Compute Is
            Is{j} = func_Is(N_in, rs, Ws(:,k), Lines(:,j), Time2, dt, tpre(n1,j));
            % Compute u
            Um(j,:) = func_um(N_in, Ws(:,k), Rm, Id, Is{j}, rm, Time2, dt, Lines(:,j), tpre(n1,j), vth);
        end
        UmAll = [UmAll, Um];
    end
    figure;
    hold on
    for j = 1:4
        plot(1:Time2*STIME, UmAll(j,:), colors(j));
    end
    hold off
    legend('Neuron1', 'Neuron2', 'Neuron3', 'Neuron4');
    axis([1, Time2*STIME, 0, 30]);
    title(['Test with character ', names(k)]);
    clear Id Is Um
end

%% Continuous stream test
% Character index presented in each of the 24 segments
seq  = [3 3 4 4 4 1 1 1 3 3 3 3 1 1 1 2 2 2 3 3 3 4 4 4];
Umss = [];
for ij = 1:STIME
    fprintf('stream segment %d\n', ij);
    W = Ws(:, seq(ij));
    for j = 1:4
        % Compute Id
        for n1 = 1:N_in
            Id{n1} = func_Id(Rd(n1), rd(n1), W(n1), Lines(n1,j), Time2, dt, tpre(n1,j));
        end
        % Compute Is
        Is{j} = func_Is(N_in, rs, W, Lines(:,j), Time2, dt, tpre(n1,j));
        % Compute u
        Um(j,:) = func_um(N_in, W, Rm, Id, Is{j}, rm, Time2, dt, Lines(:,j), tpre(n1,j), vth);
    end
    Umss = [Umss, Um];
end
figure;
hold on
for j = 1:4
    plot(1:Time2*STIME, Umss(j,:), colors(j));
end
hold off
legend('Neuron1', 'Neuron2', 'Neuron3', 'Neuron4');
axis([1, Time2*STIME, 0, 30]);
clear Id Is Um

%% Fig.5 of the paper: weight distribution before and after training
figure;
subplot(121);
bar3(w0(1:15,:), 0.8, 'r'); hold on
bar3(w0(1:12,:), 0.8, 'y'); hold on
bar3(w0(1:9,:),  0.8, 'g'); hold on
bar3(w0(1:6,:),  0.8, 'c'); hold on
bar3(w0(1:3,:),  0.8, 'b'); hold on
xlabel('Output Neuron');
ylabel('Input Neuron');
zlabel('Weight');
title('Before training');
axis([0, 5, 0, 16, 0, 1.3]);
view([-126, 36]);
subplot(122);
bar3(Ws(1:15,:), 0.8, 'r'); hold on
bar3(Ws(1:12,:), 0.8, 'y'); hold on
bar3(Ws(1:9,:),  0.8, 'g'); hold on
bar3(Ws(1:6,:),  0.8, 'c'); hold on
bar3(Ws(1:3,:),  0.8, 'b'); hold on
xlabel('Output Neuron');
ylabel('Input Neuron');
zlabel('Weight');
title('After training');
axis([0, 5, 0, 16, 0, 1.3]);
view([-126, 36]);

IV. Analysis of Simulation Results

     Running the SNN simulation reproduces results similar to those in the paper; the resulting figures are shown below:

 

 

The figures above show the actual simulation output.

V. References

[1] Gupta A., Long L. N. Hebbian learning with winner take all for spiking neural networks. In: International Joint Conference on Neural Networks (IJCNN). IEEE, 2009.
