当前位置:   article > 正文

使用 Java 读取 MNIST 数据集_java mnist

java mnist

使用 Java 读取 Mnist 数据集

0. 前言

好久没写 blog 了,没有坚持住,心中满满的负罪感!!!

上周一时冲动了,决定自己 code 一下 mlp (多层感知机)。最后的测试部分使用它来识别手写数字,也就是在 MNIST 数据集上训练并测试效果。在读取 MNIST 数据集时本打算使用轮子,可并没找到使用 Java 创造的轮子。于是,根据官网的存储格式说明自己写了一个。

遂得此文,望可抛砖引玉~~(废话少说!)

1. MNIST 数据集
  • THE MNIST DATABASE of handwritten digits
    • MachineLearing 中非常出名的数据集,它以二进制的形式存储了每个手写数字的像素及标签。下面是可视化后的一个样例图。
      手写数字 0 的样例图
    • 其他信息详见官网(点击上面的小标题可以直接进入)
  • 数据集的格式
    • IMAGE FILE (以 train-images-idx3-ubyte 为例)
      [offset] [type]          [value]          [description] 
      0000     32 bit integer  0x00000803(2051) magic number          // 魔数,就像 java 类文件中的 “CAFEBABE”。可视为一种验证,其实没有~~
      0004     32 bit integer  60000            number of images      // 表明一共有 60000 中样例
      0008     32 bit integer  28               number of rows        // 一行含有的像素点数
      0012     32 bit integer  28               number of columns     // 一列含有的像素点数
      0016     unsigned byte   ??               pixel                 // 对应像素点的值(0 ~ 255)
      0017     unsigned byte   ??               pixel 
      ........ 
      xxxx     unsigned byte   ??               pixel
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
      • 8
      • 9
    • LABEL FILE (以 train-labels-idx1-ubyte 为例)
      [offset] [type]          [value]          [description] 
      0000     32 bit integer  0x00000801(2049) magic number (MSB first)  // 同上
      0004     32 bit integer  60000            number of items           // 同上
      0008     unsigned byte   ??               label                     // 对应样本的标签,即对应图像中的手写数字是几(0 ~ 9)
      0009     unsigned byte   ??               label 
      ........ 
      xxxx     unsigned byte   ??               label
      
      • 1
      • 2
      • 3
      • 4
      • 5
      • 6
      • 7
2. 代码
  • 读取数据集代码
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;

public class MnistRead {

    public static final String TRAIN_IMAGES_FILE = "data/mnist/train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "data/mnist/train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "data/mnist/t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "data/mnist/t10k-labels.idx1-ubyte";

    /**
     * change bytes into a hex string.
     *
     * @param bytes bytes
     * @return the returned hex string
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < bytes.length; i++) {
            String hex = Integer.toHexString(bytes[i] & 0xFF);
            if (hex.length() < 2) {
                sb.append(0);
            }
            sb.append(hex);
        }
        return sb.toString();
    }

    /**
     * get images of 'train' or 'test'
     *
     * @param fileName the file of 'train' or 'test' about image
     * @return one row show a `picture`
     */
    public static double[][] getImages(String fileName) {
        double[][] x = null;
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000803".equals(bytesToHex(bytes))) {                        // 读取魔数
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);           // 读取样本总数
                bin.read(bytes, 0, 4);
                int xPixel = Integer.parseInt(bytesToHex(bytes), 16);           // 读取每行所含像素点数
                bin.read(bytes, 0, 4);
                int yPixel = Integer.parseInt(bytesToHex(bytes), 16);           // 读取每列所含像素点数
                x = new double[number][xPixel * yPixel];
                for (int i = 0; i < number; i++) {
                    double[] element = new double[xPixel * yPixel];
                    for (int j = 0; j < xPixel * yPixel; j++) {
                        element[j] = bin.read();                                // 逐一读取像素值
                        // normalization
//                        element[j] = bin.read() / 255.0;
                    }
                    x[i] = element;
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return x;
    }

    /**
     * get labels of `train` or `test`
     *
     * @param fileName the file of 'train' or 'test' about label
     * @return
     */
    public static double[] getLabels(String fileName) {
        double[] y = null;
        try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000801".equals(bytesToHex(bytes))) {
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);
                y = new double[number];
                for (int i = 0; i < number; i++) {
                    y[i] = bin.read();
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] images = getImages(TRAIN_IMAGES_FILE);
        double[] labels = getLabels(TRAIN_LABELS_FILE);

        double[][] images = getImages(TEST_IMAGES_FILE);
        double[] labels = getLabels(TEST_LABELS_FILE);

        System.out.println();
    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 显示图像代码
/**
 * draw a gray picture and the image format is JPEG.
 *
 * @param pixelValues pixelValues and ordered by column.
 * @param width       width
 * @param high        high
 * @param fileName    image saved file.
 */
public static void drawGrayPicture(int[] pixelValues, int width, int high, String fileName) throws IOException {
    BufferedImage bufferedImage = new BufferedImage(width, high, BufferedImage.TYPE_INT_RGB);
    for (int i = 0; i < width; i++) {
        for (int j = 0; j < high; j++) {
            int pixel = 255 - pixelValues[i * high + j];
            int value = pixel + (pixel << 8) + (pixel << 16);   // r = g = b 时,正好为灰度
            bufferedImage.setRGB(j, i, value);
        }
    }
    ImageIO.write(bufferedImage, "JPEG", new File(fileName));
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
3. 还有什么

上面的读取过程还是很简单的。

想分享一下自己 code 的 bp(反向传播)。希望有时间~~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/360696
推荐阅读
相关标签
  

闽ICP备14008679号