赞
踩
由于需要做的是时间序列预测,所以先选定了LSTM来做。由于不允许调用现成的库,而自己的理解又不够深入,所以手写了一个LSTM。这个模型的效果很不好,最后没有用上,但毕竟是自己手撕了这么多行代码,还是发上来和大家分享吧。代码中的主要公式都参考自 https://zybuluo.com/hanbingtao/note/581764 。不过该文的有些公式似乎有点纰漏,例如公式(57)-(62)的向量转置似乎有误。但总的来说,这仍是一篇写得非常好的文章,介绍得非常详细,在此郑重感谢作者~话不多说,直接上代码吧。
虽然最后使用的并不是这段代码,相当于白写了,但是这个过程还是很锻炼人的:算是初步了解了神经网络,也极大地提高了代码能力。这些代码几乎都是我一个人在两天之内写出来的。我本来就不是科班出身,学的是光学,平时代码量也不大,能做到这样,感觉也算不错了吧,哈哈。
1、类RunLSTM负责整体的前向传播和误差反向传播过程,也就是整个训练过程。
- package LSTMUnit;
- import java.util.ArrayList;
- import java.util.HashMap;
- import matrixUnits.MatrixUnits;
- import dataUtil.*;
- import LSTMUnit.SupportFunction;
-
-
- public class RunLSTM {//最终在这里执行
- public static void main(String[] args) {
- long startTime=System.currentTimeMillis();
-
- String filePath="C:\\Users\\lulu\\FangCloudV2\\个人文件\\华为比赛相关\\20180404lstm\\TrainData.txt";
- String[] ecsContent=FileUtil.read(filePath, null);
- double[][] allData=DataUtilLstmNew.loadDataFromStringArraynoCPUMEMNoWeek(ecsContent);
-
- allData=MatrixUnits.matrixT(allData);
- allData=InputDataProcess.addWeekDay(allData, 3);
- // MatrixUnits.printMatrix(allData);
- // System.out.println("**********************************");
- // allData=InputDataProcess.log(allData,1, Math.E);
- // MatrixUnits.printMatrix(allData);
- // System.exit(0);
-
- int dx=MatrixUnits.getRow(allData);
- int days=7;//一次输入多少天数据。下标都从0开始
- int dhc=15;//输出的向量的长度
-
- double learningRate=0.01;
- //下面都是计算过程中需要用到的变量,一共要维护8个表
- ArrayList<HashMap<String, double[][]>> deltaList=new ArrayList<HashMap<String, double[][]>>();//存放各天的误差项
- //days个数据每天有deltat,deltaot,deltaft,deltait,deltagt,共5类
- ArrayList<HashMap<String, double[][]>> gradList=new ArrayList<HashMap<String, double[][]>>();//存放各天的权重梯度和偏置项梯度
- //days个数据每天有Wfhgradt,Wihgradt,Wghgradt,Wohgradt,Wfxgradt,Wixgradt,Wgxgradt,Woxgradt,
- //bfgradt,bigradt,bggradt,bogradt,共12类
- ArrayList<double[][]> ht_1andLasthtList=new ArrayList<double[][]>();//这个在前向过程中存数据,一共要存days+1个数据
- ArrayList<double[][]> ct_landLastctList=new ArrayList<double[][]>();//这个在前向过程中存数据,一共要存days+1个数据
- ArrayList<HashMap<String, double[][]>> gateList=new ArrayList<HashMap<String, double[][]>>();//存各种门,days个数据
- //days个数据每天有ft,it,gt,ot共四类
- HashMap<String, double[][]> finalGrad=new HashMap<String, double[][]>();
- //最终的权重梯度,Wfhgrad,Wihgrad,Wghgrad,Wohgrad,Wfxgrad,Wixgrad,Wgxgrad,Woxgrad,
- //bfgrad,bigrad,bggrad,bograd,共12类
- ArrayList<HashMap<String, double[][]>> netList=new ArrayList<HashMap<String, double[][]>>();
- //days个数据,存储每个门的加权输入,分别有netft,netit,netgt,netot
- ArrayList<double[][]> xtList=new ArrayList<double[][]>();//输入表
-
- SingleCell cell=new SingleCell(dx, dhc);
- // System.out.println("firstwoh");
- // MatrixUnits.printMatrix(cell.Woh);
- initTempList(xtList,netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList, finalGrad, days);
- int predictDays=7;
- double[][] predictResult=new double[dhc][0];
- train(xtList, netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList,finalGrad,
- allData, days, cell, learningRate,dhc,dx,predictDays,predictResult);
-
- // System.out.println("lasttwoh");
- // MatrixUnits.printMatrix(cell.Woh);
-
- long endTime=System.currentTimeMillis();//记录结束时间
- double excTime=(double)(endTime-startTime)/1000;
- System.out.println("Running "+excTime+"s");
-
-
- }
-
- public static void initTempList(
- ArrayList<double[][]> xtList,
- ArrayList<HashMap<String, double[][]>> netList,
- ArrayList<HashMap<String, double[][]>> gateList,
- ArrayList<double[][]> ht_1andLasthtList,
- ArrayList<double[][]> ct_landLastctList,
- ArrayList<HashMap<String, double[][]>> deltaList,
- ArrayList<HashMap<String, double[][]>> gradList,
- HashMap<String, double[][]> finalGrad,int days) {//对7个存储中间变量的列表进行初始化
- //System.out.println("init six TempLists start:");
-
- finalGrad.put("Wfhgrad", null);
- finalGrad.put("Wihgrad", null);
- finalGrad.put("Wghgrad", null);
- finalGrad.put("Wohgrad", null);
- finalGrad.put("Wfxgrad", null);
- finalGrad.put("Wixgrad", null);
- finalGrad.put("Wgxgrad", null);
- finalGrad.put("Woxgrad", null);
- finalGrad.put("bfgrad", null);
- finalGrad.put("bigrad", null);
- finalGrad.put("bggrad", null);
- finalGrad.put("bograd", null);
-
- ht_1andLasthtList.add(null);
- ct_landLastctList.add(null);
- for(int i=0;i<days;i++) {
- xtList.add(null);
-
- netList.add(new HashMap<String, double[][]>());
- netList.get(netList.size()-1).put("netft", null);
- netList.get(netList.size()-1).put("netit", null);
- netList.get(netList.size()-1).put("netgt", null);
- netList.get(netList.size()-1).put("netot", null);
-
- gateList.add(new HashMap<String, double[][]>());
- gateList.get(gateList.size()-1).put("ft", null);
- gateList.get(gateList.size()-1).put("it", null);
- gateList.get(gateList.size()-1).put("gt", null);
- gateList.get(gateList.size()-1).put("ot", null);
-
-
- ht_1andLasthtList.add(null);
- ct_landLastctList.add(null);
-
- deltaList.add(new HashMap<String, double[][]>());
- gradList.add(new HashMap<String, double[][]>());
-
- deltaList.get(i).put("deltat", null);
- deltaList.get(i).put("deltaot", null);
- deltaList.get(i).put("deltaft", null);
- deltaList.get(i).put("deltait", null);
- deltaList.get(i).put("deltagt", null);
-
- gradList.get(i).put("Wfhgradt", null);
- gradList.get(i).put("Wihgradt", null);
- gradList.get(i).put("Wghgradt", null);
- gradList.get(i).put("Wohgradt", null);
- gradList.get(i).put("Wfxgradt", null);
- gradList.get(i).put("Wixgradt", null);
- gradList.get(i).put("Wgxgradt", null);
- gradList.get(i).put("Woxgradt", null);
- gradList.get(i).put("bfgradt", null);
- gradList.get(i).put("bigradt", null);
- gradList.get(i).put("bggradt", null);
- gradList.get(i).put("bogradt", null);
- }
- //System.out.println("init six TempLists end:");
- }
- public static void clearTempList(
- ArrayList<HashMap<String, double[][]>> netList,
- ArrayList<HashMap<String, double[][]>> gateList,
- ArrayList<double[][]> ht_1andLasthtList,
- ArrayList<double[][]> ct_landLastctList,
- ArrayList<HashMap<String, double[][]>> deltaList,
- ArrayList<HashMap<String, double[][]>> gradList,
- HashMap<String, double[][]> finalGrad) {//对7个存储中间变量的列表进行清除
- System.out.println("clean six TempLists start:");
- if(netList!=null)
- netList.clear();
- if(gateList!=null)
- gateList.clear();
- if(ht_1andLasthtList!=null)
- ht_1andLasthtList.clear();
- if(ct_landLastctList!=null)
- ct_landLastctList.clear();
- if(deltaList!=null)
- deltaList.clear();
- if(gradList!=null)
- gradList.clear();
- if(finalGrad!=null)
- finalGrad.clear();
- System.out.println("clean six TempLists end:");
- }
- public static void forwardProcess(
- ArrayList<double[][]> xtList,
- ArrayList<HashMap<String, double[][]>> netList,
- ArrayList<HashMap<String, double[][]>> gateList,
- ArrayList<double[][]> ct_landLastctList,
- ArrayList<double[][]> ht_1andLasthtList,
- double[][] trainCurrentData,SingleCell cell,int dhc,int dx) {//其实所谓的前向和后向过程,就是利用输入的一个时间窗口,
- //对一次时间窗口 移动更新各种临时变量,最后更新权值矩阵和偏置量
- //对于forwardProcess,就是通过各时刻输入表,更新各时刻ct表、net表、ht表、gate表
- //System.out.println("forwardProcess start:");
- int window=MatrixUnits.getCol(trainCurrentData);
- //1、更新输入表
- xtList.clear();//xtList是可以clear掉的,没啥问题
- for(int i=0;i<window;i++) {
- double[][] temp=MatrixUnits.getPartOfAMatrix(trainCurrentData, 0, dx-1, i, i);
- xtList.add((double[][])temp.clone());
- }
- //2、更新ct和ht表的第一项,第一次循环的时候要随机化
- if(ct_landLastctList.get(0)==null||ht_1andLasthtList.get(0)==null) {
- double[][] temp=MatrixUnits.getARandomMatrix(dhc, 1, 0, 1);
- ct_landLastctList.set(0, temp);
- ht_1andLasthtList.set(0, (double[][])temp.clone());
- }
- else {//以后的循环直接往左推移
- double[][] tempct_1=(double[][])ct_landLastctList.get(1).clone();
- double[][] tempct_2=(double[][])ht_1andLasthtList.get(1).clone();
- //clearTempList(netList, gateList, ht_1andLasthtList, ct_landLastctList, null, null, null);
- ct_landLastctList.set(0, tempct_1);
- ht_1andLasthtList.set(0, tempct_2);
- }
- //3、通过前向计算依次更新ct表,ht表,nett表和gete表,
- for(int i=1;i<window+1;i++) {
- double[][] xt=xtList.get(i-1);
- double[][] ct_1=ct_landLastctList.get(i-1);
- double[][] ht_1=ht_1andLasthtList.get(i-1);
- double[][] netft,netit,netgt,netot,ft,it,gt,ct,ot,ht;
-
- netft=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wfh, ht_1),
- MatrixUnits.matrixNormalMul(cell.Wfx, xt)),cell.bf);
- ft=SupportFunction.sigmoid(netft);
- //输入门的计算
- netit=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wih, ht_1),
- MatrixUnits.matrixNormalMul(cell.Wix, xt)),cell.bi);
- it=SupportFunction.sigmoid(netit);
- //描述当前输入的单元状态
- netgt=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wgh, ht_1),
- MatrixUnits.matrixNormalMul(cell.Wgx, xt)),cell.bg);
- gt=SupportFunction.tanh(netgt);
- //输出单元状态
- ct=MatrixUnits.matrixAdd(MatrixUnits.matrixHadamardMul(ft, ct_1), MatrixUnits.matrixHadamardMul(it, gt));
- //输出门
- netot=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Woh, ht_1),
- MatrixUnits.matrixNormalMul(cell.Wox, xt)),cell.bo);
- ot=SupportFunction.sigmoid(netot);
- //最终输出
- ht=MatrixUnits.matrixHadamardMul(ot, SupportFunction.tanh(ct));
-
- ct_landLastctList.set(i, ct);
- ht_1andLasthtList.set(i, ht);
- gateList.get(i-1).put("ft", ft);
- gateList.get(i-1).put("it", it);
- gateList.get(i-1).put("gt", gt);
- gateList.get(i-1).put("ot", ot);
- netList.get(i-1).put("netft", netft);
- netList.get(i-1).put("netit", netit);
- netList.get(i-1).put("netgt", netgt);
- netList.get(i-1).put("netot", netot);
- }
- //System.out.println("forwardProcess end:");
- }
- public static void backwardProcess(
- ArrayList<double[][]> xtList,
- ArrayList<HashMap<String, double[][]>> deltaList,
- ArrayList<double[][]> ct_landLastctList,
- ArrayList<double[][]> ht_1andLasthtList,
- ArrayList<HashMap<String, double[][]>> gateList,
- ArrayList<HashMap<String, double[][]>> gradList,
- HashMap<String, double[][]> finalGrad,
- double[][] trainCurrentData,
- SingleCell cell,int dhc,int dx,double[][] nextTime) {
- //System.out.println("backwardProcess start:");
- int window=MatrixUnits.getCol(trainCurrentData);
- //1、首先更新deltaList
- double[][] ct=ct_landLastctList.get(window);
- double[][] ot=gateList.get(window-1).get("ot");
- double[][] ft=gateList.get(window-1).get("ft");
- double[][] it=gateList.get(window-1).get("it");
- double[][] gt=gateList.get(window-1).get("gt");
- double[][] ct_1=ct_landLastctList.get(window-1);
-
- double[][] matrix1=MatrixUnits.getAZeromatrix(dhc, 1);//全1矩阵,这个一直到后面都能用的
-
- double[][] tanhct21_=MatrixUnits.matrixSub(matrix1,//1-tanh^2ct
- MatrixUnits.matrixHadamardMul(SupportFunction.tanh(ct), SupportFunction.tanh(ct)));
-
- double[][] deltat=new double[dhc][1];
- // //**************这里取的是误差函数1/2(t^2-y^2)的导数的相反数,而且是各个输出都要计算然后相加
- // for(int i=0;i<window;i++){
- // double[][] yData=ht_1andLasthtList.get(i+1);
- // double[][] tempdelta=new double[dhc][1];
- // double[][] target=new double[dhc][1];
- // if(i<window-1)
- // target=xtList.get(i+1);
- // else
- // target=nextTime.clone();
- // for(int j=0;j<dhc;j++)
- // tempdelta[j][0]=yData[i][0]*(1-yData[i][0])*(target[i][0]-yData[i][0]);
- // deltat=MatrixUnits.matrixAdd(deltat, tempdelta);
- // }
- // //*************************
- // **********这里是只算最后一个输出的delta
- double[][] yData=ht_1andLasthtList.get(window);
- double[][] target=nextTime.clone();
- for(int j=0;j<dhc;j++)
- deltat[j][0]=(target[j][0]-yData[j][0])*yData[j][0]*(1-yData[j][0]);
- // *************************
- deltaList.get(window-1).put("deltat", deltat);
-
-
- double[][] deltaot=MatrixUnits.matrixHadamardMul(MatrixUnits.matrixHadamardMul(MatrixUnits.
- matrixHadamardMul(deltat, SupportFunction.tanh(ct)),ot),MatrixUnits.matrixSub(matrix1,ot));
-
- double[][] deltaft=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
- MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(ct_1, MatrixUnits.
- matrixHadamardMul(ft, MatrixUnits.matrixSub(matrix1, ft))))));
-
- double[][] deltait=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
- MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(gt, MatrixUnits.
- matrixHadamardMul(it, MatrixUnits.matrixSub(matrix1, it))))));
-
- double[][] deltagt=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
- MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(it,
- MatrixUnits.matrixSub(matrix1, MatrixUnits.matrixHadamardMul(gt, gt))))));
-
- deltaList.get(window-1).put("deltaot", deltaot);
- deltaList.get(window-1).put("deltaft", deltaft);
- deltaList.get(window-1).put("deltait", deltait);
- deltaList.get(window-1).put("deltagt", deltagt);
- //截止到现在为止,更新完了最新一天的deltat,deltaot,deltaft,deltait,deltagt,然后后面的天就一直依靠他们来不断进行更新
- for(int i=window-2;i>=0;i--){
- double[][] tempct=ct_landLastctList.get(i+1);
- double[][] tempct_1=ct_landLastctList.get(i);
- double[][] tempot=gateList.get(i).get("ot");
- double[][] tempft=gateList.get(i).get("ft");
- double[][] tempit=gateList.get(i).get("it");
- double[][] tempgt=gateList.get(i).get("gt");
-
- double[][] temptanhct21_=MatrixUnits.matrixSub(matrix1,
- MatrixUnits.matrixHadamardMul(SupportFunction.tanh(tempct), SupportFunction.tanh(tempct)));
-
- double[][] tempdeltaotplus1T=
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。