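It has been a long time since my last blog post. I didn't keep it up, and I feel really guilty!!!

Last week, on a whim, I decided to code an MLP (multilayer perceptron) myself. The final test uses it to recognize handwritten digits, that is, it is trained and evaluated on the MNIST dataset. I planned to use an existing "wheel" to read MNIST, but I couldn't find one written in Java, so I wrote my own reader based on the storage format described on the official website.

Hence this post; I hope it can serve as a starting point for others~~ (Enough talk!)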
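MNIST is a very well-known dataset in machine learning. It stores the pixels and the label of every handwritten digit in binary form. All 32-bit header integers are stored MSB first (big-endian), and the pixels follow row by row, one unsigned byte each, where 0 is background (white) and 255 is foreground (black).

[Figure: a visualized MNIST sample]

Training-set image file:

```text
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number      // like "CAFEBABE" in a Java class file; byte 3 is the data type (0x08 = unsigned byte), byte 4 the number of dimensions (3). It can be treated as a sanity check, though not much of one~~
0004     32 bit integer  60000            number of images  // 60000 samples in total
0008     32 bit integer  28               number of rows    // image height in pixels
0012     32 bit integer  28               number of columns // image width in pixels
0016     unsigned byte   ??               pixel             // value of the pixel (0 ~ 255)
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel
```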
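Because every header field is a big-endian 32-bit integer, the header can also be read without going through a hex string. The sketch below is one possible alternative, not the approach used later in this post: it relies on `DataInputStream.readInt()`, which combines 4 bytes MSB first. The class name `IdxImageHeader` is hypothetical, and the demo parses a hand-made 16-byte header instead of the real file.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;

// Sketch: read the 16-byte header of an IDX image file with DataInputStream.
// readInt() consumes 4 bytes and combines them MSB first (big-endian),
// which matches how MNIST stores every 32-bit header field.
final class IdxImageHeader {
    final int magic; // expected 0x00000803 (2051)
    final int count; // number of images
    final int rows;  // image height
    final int cols;  // image width

    private IdxImageHeader(int magic, int count, int rows, int cols) {
        this.magic = magic;
        this.count = count;
        this.rows = rows;
        this.cols = cols;
    }

    static IdxImageHeader read(InputStream in) throws IOException {
        DataInputStream din = new DataInputStream(in);
        int magic = din.readInt(); // throws EOFException if fewer than 4 bytes remain
        if (magic != 0x00000803) {
            throw new IOException("not an IDX image file, magic = 0x" + Integer.toHexString(magic));
        }
        int count = din.readInt();
        int rows = din.readInt();
        int cols = din.readInt();
        return new IdxImageHeader(magic, count, rows, cols);
    }

    public static void main(String[] args) throws IOException {
        // Hand-made header: magic 2051, 60000 (0xEA60) images, 28 x 28 pixels.
        byte[] header = {
            0x00, 0x00, 0x08, 0x03,
            0x00, 0x00, (byte) 0xEA, 0x60,
            0x00, 0x00, 0x00, 0x1C,
            0x00, 0x00, 0x00, 0x1C
        };
        IdxImageHeader h = read(new ByteArrayInputStream(header));
        System.out.println(h.count + " images of " + h.rows + " x " + h.cols);
    }
}
```

Running `main` prints `60000 images of 28 x 28`. Since `DataInputStream` does no buffering of its own, the pixel bytes can be read from the same stream right after the header.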
Training-set label file:

```text
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first) // same as above, but with 1 dimension
0004     32 bit integer  60000            number of items          // same as above
0008     unsigned byte   ??               label                    // the label of the sample, i.e. which digit (0 ~ 9) the image shows
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
```
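The two magic numbers differ only in their last byte. Following the layout documented on the MNIST page (two zero bytes, then a type code, then the dimension count), a small decoder could look like this. `IdxMagic` and its method names are hypothetical, chosen for illustration only.

```java
// Sketch: decode an IDX magic number into the parts documented on the MNIST page.
// Bytes 1-2 are always 0, byte 3 is the element type (0x08 = unsigned byte),
// byte 4 is the number of dimensions (3 for images, 1 for labels).
final class IdxMagic {
    static boolean hasZeroPrefix(int magic) {
        return (magic >>> 16) == 0; // the two high-order bytes must be 0
    }

    static int typeCode(int magic) {
        return (magic >>> 8) & 0xFF; // third byte
    }

    static int dimensions(int magic) {
        return magic & 0xFF; // fourth (lowest) byte
    }

    public static void main(String[] args) {
        for (int magic : new int[] {0x00000803, 0x00000801}) {
            System.out.printf("0x%08X -> zero prefix %b, type 0x%02X, %d dimension(s)%n",
                    magic, hasZeroPrefix(magic), typeCode(magic), dimensions(magic));
        }
    }
}
```

Unlike Java's `CAFEBABE`, this value therefore tells a reader what the file contains, so checking the dimension byte is a cheap way to reject an image file passed where a label file is expected.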
```java
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;

public class MnistRead {

    public static final String TRAIN_IMAGES_FILE = "data/mnist/train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "data/mnist/train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "data/mnist/t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "data/mnist/t10k-labels.idx1-ubyte";

    /**
     * Change bytes into a hex string.
     *
     * @param bytes bytes
     * @return the returned hex string
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder();
        for (byte b : bytes) {
            String hex = Integer.toHexString(b & 0xFF);
            if (hex.length() < 2) {
                sb.append(0); // pad to two hex digits per byte
            }
            sb.append(hex);
        }
        return sb.toString();
    }

    /**
     * Get the images of the 'train' or 'test' set.
     *
     * @param fileName the image file of the 'train' or 'test' set
     * @return one row per picture, holding rows * columns pixel values
     */
    public static double[][] getImages(String fileName) {
        double[][] x = null;
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000803".equals(bytesToHex(bytes))) { // read the magic number
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16); // total number of samples
                bin.read(bytes, 0, 4);
                int xPixel = Integer.parseInt(bytesToHex(bytes), 16); // number of rows (image height)
                bin.read(bytes, 0, 4);
                int yPixel = Integer.parseInt(bytesToHex(bytes), 16); // number of columns (image width)
                x = new double[number][xPixel * yPixel];
                for (int i = 0; i < number; i++) {
                    for (int j = 0; j < xPixel * yPixel; j++) {
                        x[i][j] = bin.read(); // read pixel values one by one
                        // normalization
                        // x[i][j] = bin.read() / 255.0;
                    }
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return x;
    }

    /**
     * Get the labels of the 'train' or 'test' set.
     *
     * @param fileName the label file of the 'train' or 'test' set
     * @return one label (0 ~ 9) per sample
     */
    public static double[] getLabels(String fileName) {
        double[] y = null;
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000801".equals(bytesToHex(bytes))) { // read the magic number
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16); // total number of labels
                y = new double[number];
                for (int i = 0; i < number; i++) {
                    y[i] = bin.read(); // one unsigned byte per label
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return y;
    }

    public static void main(String[] args) {
        // Distinct variable names: the original declared `images` and `labels` twice,
        // which does not compile.
        double[][] trainImages = getImages(TRAIN_IMAGES_FILE);
        double[] trainLabels = getLabels(TRAIN_LABELS_FILE);
        double[][] testImages = getImages(TEST_IMAGES_FILE);
        double[] testLabels = getLabels(TEST_LABELS_FILE);
        System.out.println("train: " + trainImages.length + " images, " + trainLabels.length + " labels");
        System.out.println("test: " + testImages.length + " images, " + testLabels.length + " labels");
    }
}
```
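The commented-out line in `getImages` hints at normalization. Scaling inputs into [0, 1] is a common preprocessing step before training an MLP. The sketch below does the same division by 255.0, but after loading, so the raw matrix stays available (for example for drawing). `MnistNormalize` is a hypothetical helper, not part of `MnistRead`.

```java
import java.util.Arrays;

// Hypothetical helper, not part of MnistRead: rescales raw pixels (0 ~ 255)
// into [0, 1], the same effect as the commented-out "bin.read() / 255.0" line,
// but applied after loading and without modifying the input matrix.
final class MnistNormalize {
    static double[][] normalize(double[][] raw) {
        double[][] scaled = new double[raw.length][];
        for (int i = 0; i < raw.length; i++) {
            scaled[i] = new double[raw[i].length];
            for (int j = 0; j < raw[i].length; j++) {
                scaled[i][j] = raw[i][j] / 255.0;
            }
        }
        return scaled;
    }

    public static void main(String[] args) {
        double[][] raw = {{0, 51, 255}, {128, 255, 0}};
        System.out.println(Arrays.deepToString(normalize(raw)));
    }
}
```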
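To check a sample visually, the following method can be added to `MnistRead`; it draws one picture as a grayscale JPEG. Note that `getImages` returns `double[]` rows, so a row has to be converted to `int[]` first.

```java
// Extra imports needed in MnistRead for this method:
// import java.awt.image.BufferedImage;
// import java.io.File;
// import javax.imageio.ImageIO;

    /**
     * Draw a gray picture and save it in JPEG format.
     *
     * @param pixelValues pixel values (0 ~ 255) in row-major order, as stored in MNIST
     * @param width       image width (number of columns)
     * @param height      image height (number of rows)
     * @param fileName    file the image is saved to
     */
    public static void drawGrayPicture(int[] pixelValues, int width, int height, String fileName) throws IOException {
        BufferedImage bufferedImage = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        for (int i = 0; i < height; i++) {        // i: row (y)
            for (int j = 0; j < width; j++) {     // j: column (x)
                int pixel = 255 - pixelValues[i * width + j]; // invert so the background (0) becomes white
                int value = pixel + (pixel << 8) + (pixel << 16); // r = g = b gives a shade of gray
                bufferedImage.setRGB(j, i, value);
            }
        }
        ImageIO.write(bufferedImage, "JPEG", new File(fileName));
    }
```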
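The key detail when drawing is the indexing: MNIST pixels are stored row-wise, so pixel (row, col) of a `width`-wide image sits at index `row * width + col`, while `BufferedImage.setRGB` takes `(x = col, y = row)`. The in-memory sketch below (hypothetical class `RowMajorDraw`, no file written) applies that mapping to a deliberately non-square 3 x 2 image, where swapping width and height would show up immediately.

```java
import java.awt.image.BufferedImage;

// Sketch: build a grayscale image in memory from row-major pixel values.
// Casting double -> int covers the double[] rows returned by getImages.
final class RowMajorDraw {
    static BufferedImage toImage(double[] pixels, int width, int height) {
        BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                // invert so that 0 (background) becomes white and 255 (ink) black
                int gray = 255 - (int) pixels[row * width + col];
                img.setRGB(col, row, (gray << 16) | (gray << 8) | gray); // r = g = b
            }
        }
        return img;
    }

    public static void main(String[] args) {
        // 3 pixels wide, 2 pixels high
        double[] pixels = {0, 255, 0,
                           255, 0, 128};
        BufferedImage img = toImage(pixels, 3, 2);
        System.out.println(img.getWidth() + " x " + img.getHeight());
    }
}
```

Keeping the image in memory makes the mapping easy to check with `getRGB`; writing it out with `ImageIO.write` is then a separate step.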
As you can see, the reading process above is quite simple.
Next I'd like to share the BP (backpropagation) I coded myself. Hopefully I'll find the time~~
Copyright © 2003-2013 www.wpsshop.cn. All rights reserved.