当前位置:   article > 正文

使用DL4J对CWRU数据集进行简单分类_cwru 分类

cwru 分类

目录

0. 前言

1. 关于CWRU数据集

2. 数据读取

3.数据预处理

4. 训练


0. 前言

最近开始搞智能故障诊断方面的工作,一上来面对的就是要各种炼丹。虽然众所周知在炼丹方面是python比较擅长,但由于本人已经写了不少年java形成了路径依赖,电脑上早就装好了dl4j的环境,本着能凑合用就绝不换引擎的原则,决定拿这玩意继续对付到不能用为止。

dl4j的生态其实并不是很好,文档都不是很全;在国内生态就更差,几乎还没见到有人用过。这次写这篇文章呢,也没寻思给谁看或者能帮到谁,就当是给自己做个备忘好了。安装和基础教程方面这里推荐一下b站寒沧的教学,当年第一次安的时候也是帮了我大忙。地址:

的个人空间_哔哩哔哩_Bilibili

1. 关于CWRU数据集

这是凯斯西储大学提供的一个数据集,在故障诊断领域属于入门级的数据集,大概相当于MNIST的地位。特征非常明显,分类也非常简单。下载地址:

Download a Data File | Case School of Engineering | Case Western Reserve University

因为数据太琐碎了也并不是全都用了,这里放上我自己用的,忘了从哪下的其中一部分数据集:

链接: https://pan.baidu.com/s/1fw3bCLV7qu1ZRQxVvoxJIw 提取码: 4v24 

文件说明:(别处抄的)

文件为Matlab格式

每个文件包含风扇和驱动端振动数据,以及电机转速,文件中文件变量命名如下:

DE - drive end accelerometer data 驱动端振动数据

FE - fan end accelerometer data 风扇端振动数据

BA - base accelerometer data 基座振动数据

time - time series data 时间序列数据

RPM- rpm during testing 单位转每分钟 除以60则为转频

数据采集频率分别为

数据集A:在12Khz采样频率下的驱动端轴承故障数据

数据集B:在48Khz采样频率下的驱动端轴承故障数据

数据集C:在12Khz采样频率下的风扇端轴承故障数据

数据集D:以及正常的轴承数据(采样频率应该是48k的)

数据集B解读:在48Khz采样频率下的驱动端轴承故障直径又分为0.007英寸、0.014英寸、0.028英寸三种类别,每种故障下负载又分为0马力、1马力、2马力、3马力。在每种故障的每种马力下有轴承内圈故障、轴承滚动体故障、轴承外环故障(由于轴承外环位置一般比较固定,因此外环故障又分为3点钟、6点钟和12点钟三种类别)。

2. 数据读取

因为上面下载的源文件是matlab的格式,打开看了一下发现是二进制的,不像csv或者json那样可以很方便地自己写parser去读,所以要用别的库,这里选择JMatio。

由于DL4J必须使用Java10还是11以上的版本(反正我用的Java17),而Jar版本的JMatio因为太古老了而对高版本Java不兼容,如果强行运行的话会报错:

java.lang.reflect.InaccessibleObjectException: Unable to make public jdk.internal.ref.Cleaner java.nio.DirectByteBuffer.cleaner() accessible。

看网上其他项目的解决办法是在JVM里加启动参数:

--add-opens=java.base/java.nio=ALL-UNNAMED

但对这个库仍不好使,会继续报错:

Exception in thread "main" java.lang.NoClassDefFoundError: sun/misc/Cleaner

搞得我很是头大,差点就因为这点小事弃坑了。鼓捣了半天,最后偶然间在翻这个项目的github的时候,发现两年前的一个commit修复了对高版本Java的支持。最后的解决办法是直接把这个版本的源码扔进项目里。当然也可以自己把代码打个包,以及修复后的版本应该在Maven也是有的,我这里就懒得弄了,毕竟只是个学习项目,能凑合使就得了。地址:

GitHub - gradusnikov/jmatio: JMatIO - Matlab's MAT-file I/O in JAVA

读数据的方式也很简单暴力,直接遍历文件夹下的所有文件,然后按文件名填信息就好了。

  1. public class CWRUDataParser {
  2. public static void parse() throws Exception {
  3. var path = "你的path";
  4. for (String fname : new File(path).list()) {
  5. //System.out.println(fname);
  6. CWRUData d = new CWRUData();
  7. d.name = fname;
  8. CWRUDataManager.dataList.add(d);
  9. MatFileReader reader = new MatFileReader(path + "\\" + fname);
  10. var content = reader.getContent();
  11. if (fname.contains("_B")) {
  12. d.err_type = "Ball";
  13. } else if (fname.contains("_IR")) {
  14. d.err_type = "IR";
  15. } else if (fname.contains("_OR")) {
  16. d.err_type = "OR";
  17. } else {
  18. d.err_type = "Normal";
  19. }
  20. if (fname.contains("028")) {
  21. d.depth = 28;
  22. } else if (fname.contains("021")) {
  23. d.depth = 21;
  24. } else if (fname.contains("014")) {
  25. d.depth = 14;
  26. } else if (fname.contains("007")) {
  27. d.depth = 7;
  28. }
  29. if (fname.contains("_0_")) {
  30. d.load = 0;
  31. } else if (fname.contains("_1_")) {
  32. d.load = 1;
  33. } else if (fname.contains("_2_")) {
  34. d.load = 2;
  35. } else if (fname.contains("_3_")) {
  36. d.load = 3;
  37. }
  38. if (fname.contains("@3")) {
  39. d.pos = 3;
  40. } else if (fname.contains("@6")) {
  41. d.pos = 6;
  42. } else if (fname.contains("@12")) {
  43. d.pos = 12;
  44. }
  45. for (String key : content.keySet()) {
  46. var value = content.get(key);
  47. if (key.contains("DE")) {
  48. d.DE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
  49. //System.out.println(d.DE.length);
  50. } else if (key.contains("FE")) {
  51. d.FE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
  52. //System.out.println(d.FE.length);
  53. } else if (key.contains("BA")) {
  54. d.BA = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
  55. //System.out.println(d.BA.length);
  56. } else if (key.contains("RPM")){
  57. d.rpm = Double.valueOf(value.contentToString().split("=")[1]);
  58. //System.out.println(d.rpm);
  59. }
  60. }
  61. }
  62. }
  63. public static double[] toDoubleArray(MLArray ma) {
  64. MLDouble md = (MLDouble) ma;
  65. int m = md.getM();
  66. double[] data = new double[m];
  67. for (int i = 0; i < m; i++) {
  68. data[i] = md.get(i, 0);
  69. }
  70. return data;
  71. }
  72. public static double[] toDoubleArray4x(MLArray ma) {
  73. MLDouble md = (MLDouble) ma;
  74. int m = md.getM();
  75. double[] data = new double[m/4];
  76. for (int i = 0; i < m/4; i++) {
  77. data[i] = (md.get(4*i, 0)+md.get(4*i+1, 0)+md.get(4*i+2, 0)+md.get(4*i+3, 0))/4;
  78. }
  79. return data;
  80. }
  81. }

其中CWRUData是我自己封装的一个简单结构:

  1. public class CWRUData {
  2. public String name;
  3. public String err_type;
  4. public int depth;
  5. public int pos = 0;
  6. public int load;
  7. public double[] DE;
  8. public double[] FE;
  9. public double[] BA;
  10. public double rpm;
  11. public List<CWRUBlock> blocks = new ArrayList();
  12. public void cut(int size, int num) {
  13. for (int current = 0; current < DE.length-size ;current += DE.length/num) {
  14. double[] blockDE = new double[size];
  15. for (int i = 0; i < size; i++) {
  16. blockDE[i] = DE[i + current];
  17. }
  18. CWRUBlock block = new CWRUBlock(size, this);
  19. block.DE = blockDE;
  20. blocks.add(block);
  21. }
  22. }
  23. public int type() {
  24. if (err_type.equals("Ball")) {
  25. return depth / 7;
  26. }
  27. if (err_type.equals("IR")) {
  28. return 3 + depth / 7;
  29. }
  30. if (err_type.equals("OR")) {
  31. return 6 + depth / 7;
  32. }
  33. return 0;
  34. }
  35. public void print() {
  36. System.out.println("length:" + DE.length);
  37. System.out.println("block num:" + blocks.size());
  38. }
  39. }

3.数据预处理

补充一下上面代码里没提到的东西。一是关于toDoubleArray4x():由于故障数据的采样率是12k,而正常数据是48k,为了保证二者频率一样,因此在存入正常数据的时候使用的是四合一平均池化的toDoubleArray4x()。二是type()的作用:是把数据集根据故障类型分为10类,正常的一类,故障的三类根据depth为7/14/21每个又分为三类,1+3*3=10。由于depth为28的数据不全,舍弃。OR错误类型的数据只使用位置为6的,其他位置的数据舍弃。

CWRU数据集提供的数据是一段很长很长的离散采样序列,需要用滑窗切成一段一段的才能处理。然后分析的时候一般只用DE的数据,因为其他的好像不全。切的方法上面已经给出了,用的时候只需要:

CWRUDataManager.dataList.forEach(d->d.cut(512, 400));

其中CWRUBlock也是自己封装的一个数据类型,代码:

  1. public class CWRUBlock {
  2. public final int size;
  3. public final CWRUData source;
  4. public double[] DE;
  5. public double[] FE;
  6. public double[] BA;
  7. public CWRUBlock(int size, CWRUData source) {
  8. this.size = size;
  9. this.source = source;
  10. }
  11. }

全都处理完之后就可以做dl4j的DataSet了。代码:

  1. public static DataSet genGenericDataSet1() {
  2. List<CWRUBlock> blocks = new ArrayList();
  3. for (var data: dataList) {
  4. if (data.pos != 0 && data.pos != 6) continue;
  5. if (data.depth == 28) continue;
  6. data.blocks.forEach(blocks::add);
  7. }
  8. Collections.shuffle(blocks);
  9. INDArray[] input = blocks.stream().map(b->Nd4j.create(b.DE, 1, b.DE.length)).toArray(INDArray[]::new);
  10. INDArray inputs = Nd4j.vstack(input);
  11. INDArray[] output = blocks.stream().map(b->genOutputFromType(b.source.type())).toArray(INDArray[]::new);
  12. INDArray outputs = Nd4j.vstack(output);
  13. DataSet dataSet = new DataSet(inputs, outputs);
  14. return dataSet;
  15. }

这里为什么不用dl4j自带的shuffle,而要使用Collections.shuffle呢?这是因为我也不知道为啥dl4j自带的shuffle会直接崩掉jvm。。真是神奇的框架捏。

4. 训练

网络结构(一个非常简单的多层感知机):

  1. public static MultiLayerConfiguration CWRUANN() {
  2. MultiLayerConfiguration builder = new NeuralNetConfiguration.Builder()
  3. .seed(19260817L)
  4. .updater(new Sgd(0.01))
  5. .weightInit(WeightInit.XAVIER)
  6. .list()
  7. .layer(new DenseLayer.Builder().nIn(512).nOut(128)
  8. .activation(Activation.RELU)
  9. .build())
  10. .layer(new DenseLayer.Builder().nIn(128).nOut(32)
  11. .activation(Activation.RELU)
  12. .build())
  13. .layer(new DenseLayer.Builder().nIn(32).nOut(10)
  14. .activation(Activation.RELU)
  15. .build())
  16. .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
  17. .activation(Activation.SOFTMAX)
  18. .nIn(10).nOut(10).build())
  19. .build();
  20. return builder;
  21. }

开启监视器,以及训练:

  1. public static void ANN(DataSet train, DataSet test) throws Exception {
  2. MultiLayerNetwork model = new MultiLayerNetwork(NetFactory.CWRUANN());
  3. model.init();
  4. UIServer server = UIServer.getInstance();
  5. server.enableRemoteListener();
  6. StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
  7. model.setListeners(new StatsListener(remoteUIRouter));
  8. DataSetIterator iterator = getIter(train, 20);
  9. for (int x = 0; x< 10000; x++) {
  10. if (!iterator.hasNext()) {
  11. iterator = getIter(train, 20);
  12. }
  13. model.fit(iterator);
  14. if (x % 10 == 0) {
  15. model.save(new File("模型保存路径" + x + ".zip"), true);
  16. Evaluation eval = new Evaluation(10);
  17. INDArray output = model.output(test.getFeatures());
  18. eval.eval(test.getLabels(), output);
  19. log.info(eval.stats());
  20. }
  21. }
  22. }
  23. private static DataSetIterator getIter(final DataSet set, final int batchSize) {
  24. final List<DataSet> list = set.asList();
  25. Collections.shuffle(list, new Random());
  26. return new ListDataSetIterator(list,batchSize);
  27. }

训练结果:

  1. ========================Evaluation Metrics========================
  2. # of classes: 10
  3. Accuracy: 0.9831
  4. Precision: 0.9833
  5. Recall: 0.9834
  6. F1 Score: 0.9833
  7. Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
  8. =========================Confusion Matrix=========================
  9. 0 1 2 3 4 5 6 7 8 9
  10. -----------------------------------------
  11. 157 0 0 0 0 0 0 0 0 0 | 0 = 0
  12. 0 141 0 4 0 0 0 0 0 0 | 1 = 1
  13. 1 0 177 0 1 2 0 0 0 0 | 2 = 2
  14. 0 1 1 164 0 2 0 0 1 0 | 3 = 3
  15. 0 0 1 0 163 1 1 0 0 0 | 4 = 4
  16. 0 4 0 1 1 172 0 0 0 0 | 5 = 5
  17. 0 0 0 0 0 0 145 0 0 0 | 6 = 6
  18. 0 0 0 0 0 0 0 159 0 1 | 7 = 7
  19. 0 0 0 0 0 0 0 0 161 0 | 8 = 8
  20. 0 0 0 0 0 0 0 4 0 133 | 9 = 9
  21. Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
  22. ==================================================================

可以看出效果是非常不错的,真是简单的数据集捏。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/630851
推荐阅读
相关标签
  

闽ICP备14008679号