当前位置:   article > 正文

Java学习日记(81-90天,CNN 卷积神经网络)_java 百度旋转cnn识别

java 百度旋转cnn识别

学习地址

第 81 天: 数据集读取与存储

数据 train.format:下载地址
之前使用arff文件存储数据,现在用图片数据方式存储,按结构化的方式来存取 (m*n 点阵和类别)

  1. 这里使用了 java.util.List 类,在前面实现的数据结构,很多可以直接在 java.util 包中找到。与自己读取并使用 double[][] 来管理数据相比,List 类允许添加数据, 更加灵活。当然效率上会有点影响,可能是常数倍。
  2. tempLine.split 是以前没有使用过的功能。其实 String 类是的方法是比较丰富的,为了防止开发者乱改,该类是 final 的,不允许继承。
  3. List 是指列表里面只能存储 Instance 类型的变量。当然,Instance 的子类也行。
  4. Instance 类也是 public 的。Instance 的 label 是 Double 类型的,换成 double 也可以,如果已经确定是分类问题,换成 int 更好。
  5. 最长的就是构造函数,从文件中读入数据。
package xjx.cnn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Dataset {

	/**
	 * 所有实例是一个List表
	 */
	private List<Instance> instances;

	/**
	 * 标签索引
	 */
	private int labelIndex;

	/**
	 * 最大的标签
	 */
	private double maxLabel = -1;

	/**
	 *********************** 
	 * 第一个构造器.
	 *********************** 
	 */
	public Dataset() {
		labelIndex = -1;
		instances = new ArrayList<Instance>();
	}

	/**
	 *********************** 
	 * 第二个构造器.
	 * 
	 * @param paraFilename
	 *            The filename.
	 * @param paraSplitSign
	 *            Often comma.
	 * @param paraLabelIndex
	 *            Often the last column.
	 *********************** 
	 */
	public Dataset(String paraFilename, String paraSplitSign, int paraLabelIndex) {
		instances = new ArrayList<Instance>();
		labelIndex = paraLabelIndex;

		File tempFile = new File(paraFilename);
		try {
			BufferedReader tempReader = new BufferedReader(new FileReader(tempFile));
			String tempLine;
			while ((tempLine = tempReader.readLine()) != null) {
				String[] tempDatum = tempLine.split(paraSplitSign);
				if (tempDatum.length == 0) {
					continue;
				}

				double[] tempData = new double[tempDatum.length];
				for (int i = 0; i < tempDatum.length; i++)
					tempData[i] = Double.parseDouble(tempDatum[i]);
				Instance tempInstance = new Instance(tempData);
				append(tempInstance);
			}
			tempReader.close();
		} catch (IOException e) {
			e.printStackTrace();
			System.out.println("Unable to load " + paraFilename);
			System.exit(0);
		}
	}

	/**
	 *********************** 
	 * 追加一个实例
	 * 
	 * @param paraInstance
	 *            The given record.
	 *********************** 
	 */
	public void append(Instance paraInstance) {
		instances.add(paraInstance);
	}

	/**
	 *********************** 
	 * 追加一个实例
	 *********************** 
	 */
	public void append(double[] paraAttributes, Double paraLabel) {
		instances.add(new Instance(paraAttributes, paraLabel));
	}

	/**
	 *********************** 
	 * Getter.
	 *********************** 
	 */
	public Instance getInstance(int paraIndex) {
		return instances.get(paraIndex);
	}

	/**
	 *********************** 
	 * Getter.
	 *********************** 
	 */
	public int size() {
		return instances.size();
	}// Of size

	/**
	 *********************** 
	 * Getter.
	 *********************** 
	 */
	public double[] getAttributes(int paraIndex) {
		return instances.get(paraIndex).getAttributes();
	}

	/**
	 *********************** 
	 * Getter.
	 *********************** 
	 */
	public Double getLabel(int paraIndex) {
		return instances.get(paraIndex).getLabel();
	}

	/**
	 *********************** 
	 * Unit test.
	 *********************** 
	 */
	public static void main(String args[]) {
		Dataset tempData = new Dataset("d:/data/train.format", ",", 784);
		Instance tempInstance = tempData.getInstance(0);
		System.out.println("The first instance is: " + tempInstance);
	}

	/**
	 *********************** 
	 * An instance.
	 *********************** 
	 */
	public class Instance {
		/**
		 * 条件属性.
		 */
		private double[] attributes;

		/**
		 * 标签.
		 */
		private Double label;

		/**
		 *********************** 
		 * The first constructor.
		 *********************** 
		 */
		private Instance(double[] paraAttrs, Double paraLabel) {
			attributes = paraAttrs;
			label = paraLabel;
		}

		/**
		 *********************** 
		 * The second constructor.
		 *********************** 
		 */
		public Instance(double[] paraData) {
			if (labelIndex == -1)
				//无标签
				attributes = paraData;
			else {
				label = paraData[labelIndex];
				if (label > maxLabel) {
					//新标签
					maxLabel = label;
				}

				if (labelIndex == 0) {
					//第一列是标签
					attributes = Arrays.copyOfRange(paraData, 1, paraData.length);
				} else {
					attributes = Arrays.copyOfRange(paraData, 0, paraData.length - 1);
				}
			}
		}

		/**
		 *********************** 
		 * Getter.
		 *********************** 
		 */
		public double[] getAttributes() {
			return attributes;
		}

		/**
		 *********************** 
		 * Getter.
		 *********************** 
		 */
		public Double getLabel() {
			if (labelIndex == -1)
				return null;
			return label;
		}

		/**
		 *********************** 
		 * toString.
		 *********************** 
		 */
		public String toString(){
			return Arrays.toString(attributes) + ", " + label;
		}
	}
}


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228

一个管理卷积核尺寸的类. 基础代码, 在网络运行时才能理解它们的作用.

  1. 支持 Size 的相除. 但并不是所有的 Size 都可以除.
  2. 支持 Size 的相减.
package xjx.cnn;

public class Size {
	/**
	 * Cannot be changed after initialization.
	 */
	public final int width;

	/**
	 * Cannot be changed after initialization.
	 */
	public final int height;

	/**
	 *********************** 
	 * The first constructor.
	 * 
	 * @param paraWidth
	 *            The given width.
	 * @param paraHeight
	 *            The given height.
	 *********************** 
	 */
	public Size(int paraWidth, int paraHeight) {
		width = paraWidth;
		height = paraHeight;
	}// Of the first constructor

	/**
	 *********************** 
	 * Divide a scale with another one. For example (4, 12) / (2, 3) = (2, 4).
	 * 
	 * @param paraScaleSize
	 *            The given scale size.
	 * @return The new size.
	 *********************** 
	 */
	public Size divide(Size paraScaleSize) {
		int resultWidth = width / paraScaleSize.width;
		int resultHeight = height / paraScaleSize.height;
		if (resultWidth * paraScaleSize.width != width
				|| resultHeight * paraScaleSize.height != height)
			throw new RuntimeException("Unable to divide " + this + " with " + paraScaleSize);
		return new Size(resultWidth, resultHeight);
	}// Of divide

	/**
	 *********************** 
	 * Subtract a scale with another one, and add a value. For example (4, 12) -
	 * (2, 3) + 1 = (3, 10).
	 * 
	 * @param paraScaleSize
	 *            The given scale size.
	 * @param paraAppend
	 *            The appended size to both dimensions.
	 * @return The new size.
	 *********************** 
	 */
	public Size subtract(Size paraScaleSize, int paraAppend) {
		int resultWidth = width - paraScaleSize.width + paraAppend;
		int resultHeight = height - paraScaleSize.height + paraAppend;
		return new Size(resultWidth, resultHeight);
	}// Of subtract

	/**
	 *********************** 
	 * @param The
	 *            string showing itself.
	 *********************** 
	 */
	public String toString() {
		String resultString = "(" + width + ", " + height + ")";
		return resultString;
	}// Of toString

	/**
	 *********************** 
	 * Unit test.
	 *********************** 
	 */
	public static void main(String[] args) {
		Size tempSize1 = new Size(4, 6);
		Size tempSize2 = new Size(2, 2);
		System.out.println(
				"" + tempSize1 + " divide " + tempSize2 + " = " + tempSize1.divide(tempSize2));

		System.out.printf("a");

		try {
			System.out.println(
					"" + tempSize2 + " divide " + tempSize1 + " = " + tempSize2.divide(tempSize1));
		} catch (Exception ee) {
			System.out.println(ee);
		} // Of try

		System.out.println(
				"" + tempSize1 + " - " + tempSize2 + " + 1 = " + tempSize1.subtract(tempSize2, 1));
	}// Of main
}// Of class Size

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100

以前我们使用整数型常量 (第 51 天) 和字符型常量 (第 74 天), 其实还可以有枚举类型. 后面的程序我们才能看到其用法.

package xjx.cnn;

public enum LayerTypeEnum {
	INPUT, CONVOLUTION, SAMPLING, OUTPUT;
}//Of enum LayerTypeEnum

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

第 82 天: 数学操作

在开始今天的任务前先学习一下卷积神经网络CNN的原理。

卷积神经网络的组成

在这里插入图片描述
其中包含了几个主要结构

  • 卷积层(Convolutions)
  • 池化层(Subsampling)
  • 全连接层(Full connection)
  • 激活函数

卷积层

  • 目的
    卷积运算的目的是提取输入的不同特征,某些卷积层可能只能提取一些低级的特征如边缘、线条和角等层级,更多层的网路能从低级特征中迭代提取更复杂的特征。
  • 参数:
    size: 卷积核/过滤器大小,选择有 1 × 1 1\times1 1×1, 3 × 3 3\times3 3×3, 5 × 5 5\times5 5×5
    padding: 零填充,Valid 与Same
    stride: 步长,通常默认为1
  • 计算公式
    输入体积大小 H 1 ∗ W 1 ∗ D 1 H_1*W_1*D_1 H1W1D1
    每个超参数:
    Filter数量 K K K
    Filter大小 F F F
    步长 S S S
    零填充大小 P P P
    输出体积大小 H 2 ∗ W 2 ∗ D 2 H_2*W_2*D_2 H2W2D2
    H 2 = ( H 1 − F + 2 P ) / S + 1 H_2=(H_1-F+2P)/S+1 H2=(H1F+2P)/S+1
    W 2 = ( W 1 − F + 2 P ) / S + 1 W_2=(W_1-F+2P)/S+1 W2=(W1F+2P)/S+1
    D 2 = K D_2=K D2=K

卷积运算过程
假设是一张5 X 5 的单通道图片,通过使用3 X 3 大小的卷积核运算得到一个 3 X 3大小的运算结果(图片像素数值仅供参考)
在这里插入图片描述
我们会发现进行卷积之后的图片变小了,假设N为图片大小,F为卷积核大小
相当于 N − F + 1 = 5 − 3 + 1 = 3 N-F+1 = 5 - 3 + 1 = 3 NF+1=53+1=3

如果换一个卷积核大小或者加入很多层卷积之后,图像可能最后就变成了1 X 1 大小,这不是我们希望看到的结果。并且对于原始图片当中的边缘像素来说,只计算了一遍,二对于中间的像素会有很多次过滤器与之计算,这样导致对边缘信息的丢失。

  • 缺点
    图像变小
    边缘信息丢失

padding-零填充:
因为0在权重乘积和运算中对最终结果不造成影响,也就避免了图片增加了额外的干扰信息。所以在图片像素的最外层加上若干层0值,若一层,记做 p = 1 p=1 p=1.
在刚才的图上增加一层0值,则:
5 + 2 ∗ p − 3 + 1 = 5 5 + 2 * p - 3 + 1=5 5+2p3+1=5
P P P为1,那么最终特征结果为5。
实际上我们可以填充更多的像素,假设为2层,则
5 + 2 ∗ 2 − 3 + 1 = 7 5 + 2 * 2 - 3 + 1 = 7 5+223+1=7
这样得到的观察特征大小比之前图片大小还大。
所以我们对于零填充会有一些选择,该填充多少?

Valid and Same卷积

  • Valid: 不填充,也就是最终大小为
    ( N − F + 1 ) ∗ ( N − F + 1 ) (N - F + 1) * (N - F + 1) (NF+1)(NF+1)
  • Same: 输出大小与原图大小一致,那么 N N N变成了 N + 2 P N + 2P N+2P
    ( N + 2 P − F + 1 ) ∗ ( N + 2 P − F + 1 ) (N + 2P - F + 1) * (N + 2P - F + 1) (N+2PF+1)(N+2PF+1)

那也就意味着,之前大小与之后的大小一样,得出下面的等式
( N + 2 P − F + 1 ) = N (N + 2P - F + 1) = N (N+2PF+1)=N
P = F − 1 2 P = \frac{F -1}{2} P=2F1
所以当知道了卷积核的大小之后,就可以得出要填充多少层像素。

stride-步长:
上面的例子步长为1,将步长修改为2后:
在这里插入图片描述
这样如果以原来的计算公式,那么结果
N + 2 P − F + 1 = 6 + 0 − 3 + 1 = 4 N + 2P - F + 1 = 6 + 0 -3 +1 = 4 N+2PF+1=6+03+1=4
但是移动2个像素才得出一个结果,所以公式变为:
N + 2 P − F 2 + 1 = 1.5 + 1 = 2.5 ​ \frac{N + 2P - F}{2} + 1 = 1.5 + 1 = 2.5​ 2N+2PF+1=1.5+1=2.5
如果相除不是整数的时候,向下取整,为2。这里并没有加上零填充。
所以最终的公式就为:
N + 2 P − F S + 1 \frac{N + 2P - F}{S} + 1 SN+2PF+1
对于输入图片大小为N,过滤器大小为F,步长为S,零填充为P

池化层

池化层主要对卷积层学习到的特征图进行亚采样(subsampling)处理,主要由两种

  • 最大池化:Max Pooling,取窗口内的最大值作为输出
  • 平均池化:Avg Pooling,取窗口内的所有值的均值作为输出

意义在于:

降低了后续网络层的输入维度,缩减模型大小,提高计算速度
提高了Feature Map 的鲁棒性,防止过拟合

全连接层

卷积层+激活层+池化层可以看成是CNN的特征学习/特征提取层,而学习到的特征(Feature Map)最终应用于模型任务(分类、回归):

  • 先对所有 Feature Map 进行扁平化(flatten, 即 reshape 成 1 x N 向量)
  • 再接一个或多个全连接层,进行模型学习

代码

  1. interface Operator 定义了一个算式, 其主要目的是为了矩阵操作时对每个元素都做一遍计算, 所以要看 matrixOp 方法, 以及相应的调用才能明白其作用. 这种算式的写法比较绕, 其优点是灵活, 可以增加代码的复用性. 以 Operator 类型的变量 one_value 为例, 其最终目的是获得 1 − A \mathbf{1} - \mathbf{A} 1A 这种矩阵, 即用 1 减去矩阵的各个元素, 获得新的矩阵.
  2. interface OperatorOnTwo 与上一个类似, 不过它支持两个操作数, 进一步支持两个矩阵, 这样, 矩阵加法、减法就不需要单独写代码了.
  3. matrixOp 被重载了以支持不同的参数列表.
  4. rot180 将矩阵放置 180 度. 通过两次翻转实现.
  5. convnValid 是卷积操作. convnFull 为其逆向操作.
  6. scaleMatrix 是均值池化.
  7. kronecker 是池化的逆向操作.
package xjx.cnn;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

public class MathUtils {

	/**
	 * 不同按需操作员的界面
	 */
	public interface Operator extends Serializable {
		public double process(double value);
	}

	/**
	 * 一减去值运算符
	 */
	public static final Operator one_value = new Operator() {
		private static final long serialVersionUID = 3752139491940330714L;

		@Override
		public double process(double value) {
			return 1 - value;
		}
	};

	/**
	 * sigmoid 激活函数
	 */
	public static final Operator sigmoid = new Operator() {
		private static final long serialVersionUID = -1952718905019847589L;

		@Override
		public double process(double value) {
			return 1 / (1 + Math.pow(Math.E, -value));
		}
	};

	/**
	 * 具有两个运算符的操作接口
	 */
	interface OperatorOnTwo extends Serializable {
		public double process(double a, double b);
	}

	/**
	 * 加法.
	 */
	public static final OperatorOnTwo plus = new OperatorOnTwo() {
		private static final long serialVersionUID = -6298144029766839945L;

		@Override
		public double process(double a, double b) {
			return a + b;
		}
	};

	/**
	 * 乘法.
	 */
	public static OperatorOnTwo multiply = new OperatorOnTwo() {

		private static final long serialVersionUID = -7053767821858820698L;

		@Override
		public double process(double a, double b) {
			return a * b;
		}
	};

	/**
	 * 减法
	 */
	public static OperatorOnTwo minus = new OperatorOnTwo() {

		private static final long serialVersionUID = 7346065545555093912L;

		@Override
		public double process(double a, double b) {
			return a - b;
		}
	};

	/**
	 *********************** 
	 * 输出矩阵
	 *********************** 
	 */
	public static void printMatrix(double[][] matrix) {
		for (int i = 0; i < matrix.length; i++) {
			String line = Arrays.toString(matrix[i]);
			line = line.replaceAll(", ", "\t");
			System.out.println(line);
		}
		System.out.println();
	}

	/**
	 *********************** 
	 * 将矩阵旋转180度
	 *********************** 
	 */
	public static double[][] rot180(double[][] matrix) {
		matrix = cloneMatrix(matrix);
		int m = matrix.length;
		int n = matrix[0].length;
		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n / 2; j++) {
				double tmp = matrix[i][j];
				matrix[i][j] = matrix[i][n - 1 - j];
				matrix[i][n - 1 - j] = tmp;
			}
		}
		for (int j = 0; j < n; j++) {
			for (int i = 0; i < m / 2; i++) {
				double tmp = matrix[i][j];
				matrix[i][j] = matrix[m - 1 - i][j];
				matrix[m - 1 - i][j] = tmp;
			}
		}
		return matrix;
	}

	private static Random myRandom = new Random(2);

	/**
	 *********************** 
	 * 生成给定大小的随机矩阵。每个值取[-0.005, 0.095]中的值
	 *********************** 
	 */
	public static double[][] randomMatrix(int x, int y, boolean b) {
		double[][] matrix = new double[x][y];
		// int tag = 1;
		for (int i = 0; i < x; i++) {
			for (int j = 0; j < y; j++) {
				matrix[i][j] = (myRandom.nextDouble() - 0.05) / 10;
			}
		}
		return matrix;
	}

	/**
	 *********************** 
	 * 生成具有给定长度的随机数组。每个值取[-0.005, 0.095]中的值
	 *********************** 
	 */
	public static double[] randomArray(int len) {
		double[] data = new double[len];
		for (int i = 0; i < len; i++) {
			//data[i] = myRandom.nextDouble() / 10 - 0.05;
			data[i] = 0;
		}
		return data;
	}

	/**
	 *********************** 
	 * 生成具有批量大小的随机卷积。
	 *********************** 
	 */
	public static int[] randomPerm(int size, int batchSize) {
		Set<Integer> set = new HashSet<Integer>();
		while (set.size() < batchSize) {
			set.add(myRandom.nextInt(size));
		}
		int[] randPerm = new int[batchSize];
		int i = 0;
		for (Integer value : set)
			randPerm[i++] = value;
		return randPerm;
	}

	/**
	 *********************** 
	 * 克隆矩阵。不要直接引用它
	 *********************** 
	 */
	public static double[][] cloneMatrix(final double[][] matrix) {
		final int m = matrix.length;
		int n = matrix[0].length;
		final double[][] outMatrix = new double[m][n];

		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				outMatrix[i][j] = matrix[i][j];
			}
		}
		return outMatrix;
	}

	/**
	 *********************** 
	 * 在单个操作数上使用给定运算符的矩阵运算
	 *********************** 
	 */
	public static double[][] matrixOp(final double[][] ma, Operator operator) {
		final int m = ma.length;
		int n = ma[0].length;
		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				ma[i][j] = operator.process(ma[i][j]);
			}
		}
		return ma;
	}

	/**
	 *********************** 
	 * 在两个操作数上使用给定运算符的矩阵运算
	 *********************** 
	 */
	public static double[][] matrixOp(final double[][] ma, final double[][] mb,
			final Operator operatorA, final Operator operatorB, OperatorOnTwo operator) {
		final int m = ma.length;
		int n = ma[0].length;
		if (m != mb.length || n != mb[0].length)
			throw new RuntimeException("ma.length:" + ma.length + "  mb.length:" + mb.length);

		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				double a = ma[i][j];
				if (operatorA != null)
					a = operatorA.process(a);
				double b = mb[i][j];
				if (operatorB != null)
					b = operatorB.process(b);
				mb[i][j] = operator.process(a, b);
			}
		}
		return mb;
	}

	/**
	 *********************** 
	 * 将矩阵扩展到更大的矩阵(多次)
	 *********************** 
	 */
	public static double[][] kronecker(final double[][] matrix, final Size scale) {
		final int m = matrix.length;
		int n = matrix[0].length;
		final double[][] outMatrix = new double[m * scale.width][n * scale.height];

		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				for (int ki = i * scale.width; ki < (i + 1) * scale.width; ki++) {
					for (int kj = j * scale.height; kj < (j + 1) * scale.height; kj++) {
						outMatrix[ki][kj] = matrix[i][j];
					}
				}
			}
		}
		return outMatrix;
	}

	/**
	 *********************** 
	 * 缩放矩阵
	 *********************** 
	 */
	public static double[][] scaleMatrix(final double[][] matrix, final Size scale) {
		int m = matrix.length;
		int n = matrix[0].length;
		final int sm = m / scale.width;
		final int sn = n / scale.height;
		final double[][] outMatrix = new double[sm][sn];
		if (sm * scale.width != m || sn * scale.height != n)
			throw new RuntimeException("scale matrix");
		final int size = scale.width * scale.height;
		for (int i = 0; i < sm; i++) {
			for (int j = 0; j < sn; j++) {
				double sum = 0.0;
				for (int si = i * scale.width; si < (i + 1) * scale.width; si++) {
					for (int sj = j * scale.height; sj < (j + 1) * scale.height; sj++) {
						sum += matrix[si][sj];
					}
				}
				outMatrix[i][j] = sum / size;
			}
		}
		return outMatrix;
	}

	/**
	 *********************** 
	 * 完全卷积以获得更大的尺寸。它用于反向传播
	 *********************** 
	 */
	public static double[][] convnFull(double[][] matrix, final double[][] kernel) {
		int m = matrix.length;
		int n = matrix[0].length;
		final int km = kernel.length;
		final int kn = kernel[0].length;
		final double[][] extendMatrix = new double[m + 2 * (km - 1)][n + 2 * (kn - 1)];
		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				extendMatrix[i + km - 1][j + kn - 1] = matrix[i][j];
			}
		}
		return convnValid(extendMatrix, kernel);
	}

	/**
	 *********************** 
	 * 卷积运算,从给定的矩阵和核,滑动和求和,以获得结果矩阵。它用于向前预测
	 *********************** 
	 */
	public static double[][] convnValid(final double[][] matrix, double[][] kernel) {
		// kernel = rot180(kernel);
		int m = matrix.length;
		int n = matrix[0].length;
		final int km = kernel.length;
		final int kn = kernel[0].length;
		int kns = n - kn + 1;
		final int kms = m - km + 1;
		final double[][] outMatrix = new double[kms][kns];

		for (int i = 0; i < kms; i++) {
			for (int j = 0; j < kns; j++) {
				double sum = 0.0;
				for (int ki = 0; ki < km; ki++) {
					for (int kj = 0; kj < kn; kj++)
						sum += matrix[i + ki][j + kj] * kernel[ki][kj];
				}
				outMatrix[i][j] = sum;

			}
		}
		return outMatrix;
	}

	/**
	 *********************** 
	 * 张量上的卷积
	 *********************** 
	 */
	public static double[][] convnValid(final double[][][][] matrix, int mapNoX,
			double[][][][] kernel, int mapNoY) {
		int m = matrix.length;
		int n = matrix[0][mapNoX].length;
		int h = matrix[0][mapNoX][0].length;
		int km = kernel.length;
		int kn = kernel[0][mapNoY].length;
		int kh = kernel[0][mapNoY][0].length;
		int kms = m - km + 1;
		int kns = n - kn + 1;
		int khs = h - kh + 1;
		if (matrix.length != kernel.length)
			throw new RuntimeException("length");
		final double[][][] outMatrix = new double[kms][kns][khs];
		for (int i = 0; i < kms; i++) {
			for (int j = 0; j < kns; j++)
				for (int k = 0; k < khs; k++) {
					double sum = 0.0;
					for (int ki = 0; ki < km; ki++) {
						for (int kj = 0; kj < kn; kj++)
							for (int kk = 0; kk < kh; kk++) {
								sum += matrix[i + ki][mapNoX][j + kj][k + kk]
										* kernel[ki][mapNoY][kj][kk];
							}
					}
					outMatrix[i][j][k] = sum;
				}
		}
		return outMatrix[0];
	}

	/**
	 *********************** 
	 * sigmod 操作
	 *********************** 
	 */
	public static double sigmod(double x) {
		return 1 / (1 + Math.pow(Math.E, -x));
	}

	/**
	 *********************** 
	 * 求矩阵的所有值之和
	 *********************** 
	 */
	public static double sum(double[][] error) {
		int m = error.length;
		int n = error[0].length;
		double sum = 0.0;
		for (int i = 0; i < m; i++) {
			for (int j = 0; j < n; j++) {
				sum += error[i][j];
			}
		}
		return sum;
	}

	/**
	 *********************** 
	 * Ad hoc sum.
	 *********************** 
	 */
	public static double[][] sum(double[][][][] errors, int j) {
		int m = errors[0][j].length;
		int n = errors[0][j][0].length;
		double[][] result = new double[m][n];
		for (int mi = 0; mi < m; mi++) {
			for (int nj = 0; nj < n; nj++) {
				double sum = 0;
				for (int i = 0; i < errors.length; i++)
					sum += errors[i][j][mi][nj];
				result[mi][nj] = sum;
			}
		}
		return result;
	}

	/**
	 *********************** 
	 * 获取最终分类的最大值索引
	 *********************** 
	 */
	public static int getMaxIndex(double[] out) {
		double max = out[0];
		int index = 0;
		for (int i = 1; i < out.length; i++)
			if (out[i] > max) {
				max = out[i];
				index = i;
			}
		return index;
	}
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431

第 83 天: 数学操作 (续)

昨天的操作比较多, 今天可以自己写些代码来测试下. 例如, 测试用例最好是 8*8 之内的矩阵, 这样比较容易验证正确性.
部分测试代码:

public static void main(String args[]) {
	MathUtils tempmathUtils = new MathUtils();
	
	double[][] matrix = new double[8][8];
	matrix = tempmathUtils.randomMatrix(8, 8, true);
	printMatrix(matrix);
	//matrix = rot180(matrix);
	//printMatrix(matrix);
	
	double[] Array = new double[8];
	Array = tempmathUtils.randomArray(8);
	
	int[] Perm = new int[10];
	Perm = tempmathUtils.randomPerm(20,10);
	
	Operator aa;
		double[][] tempmatrixOp = new double[8][8];
	tempmatrixOp = tempmathUtils.matrixOp(tempmatrixOp,aa);
	printMatrix(matrix);
	
	Operator a1,b1;
	OperatorOnTwo c1;
	double[][] tempmatrixOpa = new double[8][8];
	double[][] tempmatrixOpb = new double[8][8];
	tempmatrixOpa = matrixOp(tempmatrixOpa, tempmatrixOpb, a1, b1, c1);
	printMatrix(matrix);
	//System.out.println();
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/715914
推荐阅读
相关标签
  

闽ICP备14008679号