当前位置:   article > 正文

2018华为软件精英挑战赛系列3——自己手撕的LSTM_手撕lstm代码

手撕lstm代码

    由于需要做的是时间序列预测,所以先选定了LSTM来做。由于比赛不允许调用第三方库(即不给调包),加上自己对LSTM的理解还不够深入,所以自己手写了一个LSTM。不过这个模型的效果很不好,最终没有采用,但毕竟也是自己手写了这么多行代码,就发上来和大家分享吧。代码中主要的公式都参考自 https://zybuluo.com/hanbingtao/note/581764 。不过该文的个别公式似乎有些纰漏,例如公式(57)-(62)中的向量转置似乎有误。但总的来说,这仍是一篇写得非常好的文章,介绍得非常详细,在此郑重感谢作者~话不多说,直接上代码吧。

    虽然最后并没有使用这段代码,相当于白写了,但这个过程还是很锻炼人的:算是初步了解了神经网络,也极大地提高了编码能力……这些代码几乎是我一个人在两天之内写出来的。我本来就不是科班出身,学的是光学,平时代码量也不大,能做到这样,感觉也算不错了吧,哈哈。

1、类RunLSTM负责总体的前向传播和反向误差传播过程,也就是整个训练过程。

  1. package LSTMUnit;
  2. import java.util.ArrayList;
  3. import java.util.HashMap;
  4. import matrixUnits.MatrixUnits;
  5. import dataUtil.*;
  6. import LSTMUnit.SupportFunction;
  7. public class RunLSTM {//最终在这里执行
  8. public static void main(String[] args) {
  9. long startTime=System.currentTimeMillis();
  10. String filePath="C:\\Users\\lulu\\FangCloudV2\\个人文件\\华为比赛相关\\20180404lstm\\TrainData.txt";
  11. String[] ecsContent=FileUtil.read(filePath, null);
  12. double[][] allData=DataUtilLstmNew.loadDataFromStringArraynoCPUMEMNoWeek(ecsContent);
  13. allData=MatrixUnits.matrixT(allData);
  14. allData=InputDataProcess.addWeekDay(allData, 3);
  15. // MatrixUnits.printMatrix(allData);
  16. // System.out.println("**********************************");
  17. // allData=InputDataProcess.log(allData,1, Math.E);
  18. // MatrixUnits.printMatrix(allData);
  19. // System.exit(0);
  20. int dx=MatrixUnits.getRow(allData);
  21. int days=7;//一次输入多少天数据。下标都从0开始
  22. int dhc=15;//输出的向量的长度
  23. double learningRate=0.01;
  24. //下面都是计算过程中需要用到的变量,一共要维护8个表
  25. ArrayList<HashMap<String, double[][]>> deltaList=new ArrayList<HashMap<String, double[][]>>();//存放各天的误差项
  26. //days个数据每天有deltat,deltaot,deltaft,deltait,deltagt,共5类
  27. ArrayList<HashMap<String, double[][]>> gradList=new ArrayList<HashMap<String, double[][]>>();//存放各天的权重梯度和偏置项梯度
  28. //days个数据每天有Wfhgradt,Wihgradt,Wghgradt,Wohgradt,Wfxgradt,Wixgradt,Wgxgradt,Woxgradt,
  29. //bfgradt,bigradt,bggradt,bogradt,共12类
  30. ArrayList<double[][]> ht_1andLasthtList=new ArrayList<double[][]>();//这个在前向过程中存数据,一共要存days+1个数据
  31. ArrayList<double[][]> ct_landLastctList=new ArrayList<double[][]>();//这个在前向过程中存数据,一共要存days+1个数据
  32. ArrayList<HashMap<String, double[][]>> gateList=new ArrayList<HashMap<String, double[][]>>();//存各种门,days个数据
  33. //days个数据每天有ft,it,gt,ot共四类
  34. HashMap<String, double[][]> finalGrad=new HashMap<String, double[][]>();
  35. //最终的权重梯度,Wfhgrad,Wihgrad,Wghgrad,Wohgrad,Wfxgrad,Wixgrad,Wgxgrad,Woxgrad,
  36. //bfgrad,bigrad,bggrad,bograd,共12类
  37. ArrayList<HashMap<String, double[][]>> netList=new ArrayList<HashMap<String, double[][]>>();
  38. //days个数据,存储每个门的加权输入,分别有netft,netit,netgt,netot
  39. ArrayList<double[][]> xtList=new ArrayList<double[][]>();//输入表
  40. SingleCell cell=new SingleCell(dx, dhc);
  41. // System.out.println("firstwoh");
  42. // MatrixUnits.printMatrix(cell.Woh);
  43. initTempList(xtList,netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList, finalGrad, days);
  44. int predictDays=7;
  45. double[][] predictResult=new double[dhc][0];
  46. train(xtList, netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList,finalGrad,
  47. allData, days, cell, learningRate,dhc,dx,predictDays,predictResult);
  48. // System.out.println("lasttwoh");
  49. // MatrixUnits.printMatrix(cell.Woh);
  50. long endTime=System.currentTimeMillis();//记录结束时间
  51. double excTime=(double)(endTime-startTime)/1000;
  52. System.out.println("Running "+excTime+"s");
  53. }
  54. public static void initTempList(
  55. ArrayList<double[][]> xtList,
  56. ArrayList<HashMap<String, double[][]>> netList,
  57. ArrayList<HashMap<String, double[][]>> gateList,
  58. ArrayList<double[][]> ht_1andLasthtList,
  59. ArrayList<double[][]> ct_landLastctList,
  60. ArrayList<HashMap<String, double[][]>> deltaList,
  61. ArrayList<HashMap<String, double[][]>> gradList,
  62. HashMap<String, double[][]> finalGrad,int days) {//对7个存储中间变量的列表进行初始化
  63. //System.out.println("init six TempLists start:");
  64. finalGrad.put("Wfhgrad", null);
  65. finalGrad.put("Wihgrad", null);
  66. finalGrad.put("Wghgrad", null);
  67. finalGrad.put("Wohgrad", null);
  68. finalGrad.put("Wfxgrad", null);
  69. finalGrad.put("Wixgrad", null);
  70. finalGrad.put("Wgxgrad", null);
  71. finalGrad.put("Woxgrad", null);
  72. finalGrad.put("bfgrad", null);
  73. finalGrad.put("bigrad", null);
  74. finalGrad.put("bggrad", null);
  75. finalGrad.put("bograd", null);
  76. ht_1andLasthtList.add(null);
  77. ct_landLastctList.add(null);
  78. for(int i=0;i<days;i++) {
  79. xtList.add(null);
  80. netList.add(new HashMap<String, double[][]>());
  81. netList.get(netList.size()-1).put("netft", null);
  82. netList.get(netList.size()-1).put("netit", null);
  83. netList.get(netList.size()-1).put("netgt", null);
  84. netList.get(netList.size()-1).put("netot", null);
  85. gateList.add(new HashMap<String, double[][]>());
  86. gateList.get(gateList.size()-1).put("ft", null);
  87. gateList.get(gateList.size()-1).put("it", null);
  88. gateList.get(gateList.size()-1).put("gt", null);
  89. gateList.get(gateList.size()-1).put("ot", null);
  90. ht_1andLasthtList.add(null);
  91. ct_landLastctList.add(null);
  92. deltaList.add(new HashMap<String, double[][]>());
  93. gradList.add(new HashMap<String, double[][]>());
  94. deltaList.get(i).put("deltat", null);
  95. deltaList.get(i).put("deltaot", null);
  96. deltaList.get(i).put("deltaft", null);
  97. deltaList.get(i).put("deltait", null);
  98. deltaList.get(i).put("deltagt", null);
  99. gradList.get(i).put("Wfhgradt", null);
  100. gradList.get(i).put("Wihgradt", null);
  101. gradList.get(i).put("Wghgradt", null);
  102. gradList.get(i).put("Wohgradt", null);
  103. gradList.get(i).put("Wfxgradt", null);
  104. gradList.get(i).put("Wixgradt", null);
  105. gradList.get(i).put("Wgxgradt", null);
  106. gradList.get(i).put("Woxgradt", null);
  107. gradList.get(i).put("bfgradt", null);
  108. gradList.get(i).put("bigradt", null);
  109. gradList.get(i).put("bggradt", null);
  110. gradList.get(i).put("bogradt", null);
  111. }
  112. //System.out.println("init six TempLists end:");
  113. }
  114. public static void clearTempList(
  115. ArrayList<HashMap<String, double[][]>> netList,
  116. ArrayList<HashMap<String, double[][]>> gateList,
  117. ArrayList<double[][]> ht_1andLasthtList,
  118. ArrayList<double[][]> ct_landLastctList,
  119. ArrayList<HashMap<String, double[][]>> deltaList,
  120. ArrayList<HashMap<String, double[][]>> gradList,
  121. HashMap<String, double[][]> finalGrad) {//对7个存储中间变量的列表进行清除
  122. System.out.println("clean six TempLists start:");
  123. if(netList!=null)
  124. netList.clear();
  125. if(gateList!=null)
  126. gateList.clear();
  127. if(ht_1andLasthtList!=null)
  128. ht_1andLasthtList.clear();
  129. if(ct_landLastctList!=null)
  130. ct_landLastctList.clear();
  131. if(deltaList!=null)
  132. deltaList.clear();
  133. if(gradList!=null)
  134. gradList.clear();
  135. if(finalGrad!=null)
  136. finalGrad.clear();
  137. System.out.println("clean six TempLists end:");
  138. }
  139. public static void forwardProcess(
  140. ArrayList<double[][]> xtList,
  141. ArrayList<HashMap<String, double[][]>> netList,
  142. ArrayList<HashMap<String, double[][]>> gateList,
  143. ArrayList<double[][]> ct_landLastctList,
  144. ArrayList<double[][]> ht_1andLasthtList,
  145. double[][] trainCurrentData,SingleCell cell,int dhc,int dx) {//其实所谓的前向和后向过程,就是利用输入的一个时间窗口,
  146. //对一次时间窗口 移动更新各种临时变量,最后更新权值矩阵和偏置量
  147. //对于forwardProcess,就是通过各时刻输入表,更新各时刻ct表、net表、ht表、gate表
  148. //System.out.println("forwardProcess start:");
  149. int window=MatrixUnits.getCol(trainCurrentData);
  150. //1、更新输入表
  151. xtList.clear();//xtList是可以clear掉的,没啥问题
  152. for(int i=0;i<window;i++) {
  153. double[][] temp=MatrixUnits.getPartOfAMatrix(trainCurrentData, 0, dx-1, i, i);
  154. xtList.add((double[][])temp.clone());
  155. }
  156. //2、更新ct和ht表的第一项,第一次循环的时候要随机化
  157. if(ct_landLastctList.get(0)==null||ht_1andLasthtList.get(0)==null) {
  158. double[][] temp=MatrixUnits.getARandomMatrix(dhc, 1, 0, 1);
  159. ct_landLastctList.set(0, temp);
  160. ht_1andLasthtList.set(0, (double[][])temp.clone());
  161. }
  162. else {//以后的循环直接往左推移
  163. double[][] tempct_1=(double[][])ct_landLastctList.get(1).clone();
  164. double[][] tempct_2=(double[][])ht_1andLasthtList.get(1).clone();
  165. //clearTempList(netList, gateList, ht_1andLasthtList, ct_landLastctList, null, null, null);
  166. ct_landLastctList.set(0, tempct_1);
  167. ht_1andLasthtList.set(0, tempct_2);
  168. }
  169. //3、通过前向计算依次更新ct表,ht表,nett表和gete表,
  170. for(int i=1;i<window+1;i++) {
  171. double[][] xt=xtList.get(i-1);
  172. double[][] ct_1=ct_landLastctList.get(i-1);
  173. double[][] ht_1=ht_1andLasthtList.get(i-1);
  174. double[][] netft,netit,netgt,netot,ft,it,gt,ct,ot,ht;
  175. netft=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wfh, ht_1),
  176. MatrixUnits.matrixNormalMul(cell.Wfx, xt)),cell.bf);
  177. ft=SupportFunction.sigmoid(netft);
  178. //输入门的计算
  179. netit=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wih, ht_1),
  180. MatrixUnits.matrixNormalMul(cell.Wix, xt)),cell.bi);
  181. it=SupportFunction.sigmoid(netit);
  182. //描述当前输入的单元状态
  183. netgt=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wgh, ht_1),
  184. MatrixUnits.matrixNormalMul(cell.Wgx, xt)),cell.bg);
  185. gt=SupportFunction.tanh(netgt);
  186. //输出单元状态
  187. ct=MatrixUnits.matrixAdd(MatrixUnits.matrixHadamardMul(ft, ct_1), MatrixUnits.matrixHadamardMul(it, gt));
  188. //输出门
  189. netot=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Woh, ht_1),
  190. MatrixUnits.matrixNormalMul(cell.Wox, xt)),cell.bo);
  191. ot=SupportFunction.sigmoid(netot);
  192. //最终输出
  193. ht=MatrixUnits.matrixHadamardMul(ot, SupportFunction.tanh(ct));
  194. ct_landLastctList.set(i, ct);
  195. ht_1andLasthtList.set(i, ht);
  196. gateList.get(i-1).put("ft", ft);
  197. gateList.get(i-1).put("it", it);
  198. gateList.get(i-1).put("gt", gt);
  199. gateList.get(i-1).put("ot", ot);
  200. netList.get(i-1).put("netft", netft);
  201. netList.get(i-1).put("netit", netit);
  202. netList.get(i-1).put("netgt", netgt);
  203. netList.get(i-1).put("netot", netot);
  204. }
  205. //System.out.println("forwardProcess end:");
  206. }
  207. public static void backwardProcess(
  208. ArrayList<double[][]> xtList,
  209. ArrayList<HashMap<String, double[][]>> deltaList,
  210. ArrayList<double[][]> ct_landLastctList,
  211. ArrayList<double[][]> ht_1andLasthtList,
  212. ArrayList<HashMap<String, double[][]>> gateList,
  213. ArrayList<HashMap<String, double[][]>> gradList,
  214. HashMap<String, double[][]> finalGrad,
  215. double[][] trainCurrentData,
  216. SingleCell cell,int dhc,int dx,double[][] nextTime) {
  217. //System.out.println("backwardProcess start:");
  218. int window=MatrixUnits.getCol(trainCurrentData);
  219. //1、首先更新deltaList
  220. double[][] ct=ct_landLastctList.get(window);
  221. double[][] ot=gateList.get(window-1).get("ot");
  222. double[][] ft=gateList.get(window-1).get("ft");
  223. double[][] it=gateList.get(window-1).get("it");
  224. double[][] gt=gateList.get(window-1).get("gt");
  225. double[][] ct_1=ct_landLastctList.get(window-1);
  226. double[][] matrix1=MatrixUnits.getAZeromatrix(dhc, 1);//全1矩阵,这个一直到后面都能用的
  227. double[][] tanhct21_=MatrixUnits.matrixSub(matrix1,//1-tanh^2ct
  228. MatrixUnits.matrixHadamardMul(SupportFunction.tanh(ct), SupportFunction.tanh(ct)));
  229. double[][] deltat=new double[dhc][1];
  230. // //**************这里取的是误差函数1/2(t^2-y^2)的导数的相反数,而且是各个输出都要计算然后相加
  231. // for(int i=0;i<window;i++){
  232. // double[][] yData=ht_1andLasthtList.get(i+1);
  233. // double[][] tempdelta=new double[dhc][1];
  234. // double[][] target=new double[dhc][1];
  235. // if(i<window-1)
  236. // target=xtList.get(i+1);
  237. // else
  238. // target=nextTime.clone();
  239. // for(int j=0;j<dhc;j++)
  240. // tempdelta[j][0]=yData[i][0]*(1-yData[i][0])*(target[i][0]-yData[i][0]);
  241. // deltat=MatrixUnits.matrixAdd(deltat, tempdelta);
  242. // }
  243. // //*************************
  244. // **********这里是只算最后一个输出的delta
  245. double[][] yData=ht_1andLasthtList.get(window);
  246. double[][] target=nextTime.clone();
  247. for(int j=0;j<dhc;j++)
  248. deltat[j][0]=(target[j][0]-yData[j][0])*yData[j][0]*(1-yData[j][0]);
  249. // *************************
  250. deltaList.get(window-1).put("deltat", deltat);
  251. double[][] deltaot=MatrixUnits.matrixHadamardMul(MatrixUnits.matrixHadamardMul(MatrixUnits.
  252. matrixHadamardMul(deltat, SupportFunction.tanh(ct)),ot),MatrixUnits.matrixSub(matrix1,ot));
  253. double[][] deltaft=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
  254. MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(ct_1, MatrixUnits.
  255. matrixHadamardMul(ft, MatrixUnits.matrixSub(matrix1, ft))))));
  256. double[][] deltait=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
  257. MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(gt, MatrixUnits.
  258. matrixHadamardMul(it, MatrixUnits.matrixSub(matrix1, it))))));
  259. double[][] deltagt=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
  260. MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(it,
  261. MatrixUnits.matrixSub(matrix1, MatrixUnits.matrixHadamardMul(gt, gt))))));
  262. deltaList.get(window-1).put("deltaot", deltaot);
  263. deltaList.get(window-1).put("deltaft", deltaft);
  264. deltaList.get(window-1).put("deltait", deltait);
  265. deltaList.get(window-1).put("deltagt", deltagt);
  266. //截止到现在为止,更新完了最新一天的deltat,deltaot,deltaft,deltait,deltagt,然后后面的天就一直依靠他们来不断进行更新
  267. for(int i=window-2;i>=0;i--){
  268. double[][] tempct=ct_landLastctList.get(i+1);
  269. double[][] tempct_1=ct_landLastctList.get(i);
  270. double[][] tempot=gateList.get(i).get("ot");
  271. double[][] tempft=gateList.get(i).get("ft");
  272. double[][] tempit=gateList.get(i).get("it");
  273. double[][] tempgt=gateList.get(i).get("gt");
  274. double[][] temptanhct21_=MatrixUnits.matrixSub(matrix1,
  275. MatrixUnits.matrixHadamardMul(SupportFunction.tanh(tempct), SupportFunction.tanh(tempct)));
  276. double[][] tempdeltaotplus1T=
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号