赞
踩
目录
最近开始搞智能故障诊断方面的工作,一上来面对的就是要各种炼丹。虽然众所周知在炼丹方面是python比较擅长,但由于本人已经写了不少年java形成了路径依赖,电脑上早就装好了dl4j的环境,本着能凑合用就绝不换引擎的原则,决定拿这玩意继续对付到不能用为止。
dl4j的生态其实并不是很好,文档都不是很全;在国内生态就更差,几乎还没见到有人用过。这次写这篇文章呢,也没寻思给谁看或者能帮到谁,就当是给自己做个备忘好了。安装和基础教程方面这里推荐一下b站寒沧的教学,当年第一次安的时候也是帮了我大忙。地址:
这是凯斯西储大学提供的一个数据集,在故障诊断领域属于入门级的数据集,大概相当于MNIST的地位。特征非常明显,分类也非常简单。下载地址:
Download a Data File | Case School of Engineering | Case Western Reserve University
因为数据太琐碎了也并不是全都用了,这里放上我自己用的,忘了从哪下的其中一部分数据集:
链接: https://pan.baidu.com/s/1fw3bCLV7qu1ZRQxVvoxJIw 提取码: 4v24
文件说明:(别处抄的)
文件为Matlab格式
每个文件包含风扇和驱动端振动数据,以及电机转速,文件中文件变量命名如下:
DE - drive end accelerometer data 驱动端振动数据
FE - fan end accelerometer data 风扇端振动数据
BA - base accelerometer data 基座振动数据
time - time series data 时间序列数据
RPM- rpm during testing 单位转每分钟 除以60则为转频
数据采集频率分别为:
数据集A:在12Khz采样频率下的驱动端轴承故障数据
数据集B:在48Khz采样频率下的驱动端轴承故障数据
数据集C:在12Khz采样频率下的风扇端轴承故障数据
数据集D:以及正常的轴承数据(采样频率应该是48k的)
数据集B解读:在48Khz采样频率下的驱动端轴承故障直径又分为0.007英寸、0.014英寸、0.028英寸三种类别,每种故障下负载又分为0马力、1马力、2马力、3马力。在每种故障的每种马力下有轴承内圈故障、轴承滚动体故障、轴承外环故障(由于轴承外环位置一般比较固定,因此外环故障又分为3点钟、6点钟和12点钟三种类别)。
因为上面下载的源文件是matlab的格式,打开看了一下发现是二进制的,不像csv或者json那样可以很方便地自己写parser去读,所以要用别的库,这里选择JMatio。
由于DL4J必须使用Java10还是11以上的版本(反正我用的Java17),而Jar版本的JMatio因为太古老了而对高版本Java不兼容,如果强行运行的话会报错:
java.lang.reflect.InaccessibleObjectException: Unable to make public jdk.internal.ref.Cleaner java.nio.DirectByteBuffer.cleaner() accessible。
看网上其他项目的解决办法是在JVM里加启动参数:
--add-opens=java.base/java.nio=ALL-UNNAMED
但对这个库仍不好使,会继续报错:
Exception in thread "main" java.lang.NoClassDefFoundError: sun/misc/Cleaner
搞得我很是头大,差点就因为这点小事弃坑了。鼓捣了半天,最后偶然间在翻这个项目的github的时候,发现两年前的一个commit修复了对高版本Java的支持。最后的解决办法是直接把这个版本的源码扔进项目里。当然也可以自己把代码打个包,以及修复后的版本应该在Maven也是有的,我这里就懒得弄了,毕竟只是个学习项目,能凑合使就得了。地址:
GitHub - gradusnikov/jmatio: JMatIO - Matlab's MAT-file I/O in JAVA
读数据的方式也很简单暴力,直接遍历文件夹下的所有文件,然后按文件名填信息就好了。
- public class CWRUDataParser {
-
- public static void parse() throws Exception {
- var path = "你的path";
- for (String fname : new File(path).list()) {
- //System.out.println(fname);
- CWRUData d = new CWRUData();
- d.name = fname;
- CWRUDataManager.dataList.add(d);
-
- MatFileReader reader = new MatFileReader(path + "\\" + fname);
- var content = reader.getContent();
-
- if (fname.contains("_B")) {
- d.err_type = "Ball";
- } else if (fname.contains("_IR")) {
- d.err_type = "IR";
- } else if (fname.contains("_OR")) {
- d.err_type = "OR";
- } else {
- d.err_type = "Normal";
- }
-
- if (fname.contains("028")) {
- d.depth = 28;
- } else if (fname.contains("021")) {
- d.depth = 21;
- } else if (fname.contains("014")) {
- d.depth = 14;
- } else if (fname.contains("007")) {
- d.depth = 7;
- }
-
- if (fname.contains("_0_")) {
- d.load = 0;
- } else if (fname.contains("_1_")) {
- d.load = 1;
- } else if (fname.contains("_2_")) {
- d.load = 2;
- } else if (fname.contains("_3_")) {
- d.load = 3;
- }
-
- if (fname.contains("@3")) {
- d.pos = 3;
- } else if (fname.contains("@6")) {
- d.pos = 6;
- } else if (fname.contains("@12")) {
- d.pos = 12;
- }
-
- for (String key : content.keySet()) {
- var value = content.get(key);
- if (key.contains("DE")) {
- d.DE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
- //System.out.println(d.DE.length);
- } else if (key.contains("FE")) {
- d.FE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
- //System.out.println(d.FE.length);
- } else if (key.contains("BA")) {
- d.BA = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
- //System.out.println(d.BA.length);
- } else if (key.contains("RPM")){
- d.rpm = Double.valueOf(value.contentToString().split("=")[1]);
- //System.out.println(d.rpm);
- }
- }
- }
- }
-
- public static double[] toDoubleArray(MLArray ma) {
- MLDouble md = (MLDouble) ma;
- int m = md.getM();
- double[] data = new double[m];
- for (int i = 0; i < m; i++) {
- data[i] = md.get(i, 0);
- }
- return data;
- }
-
- public static double[] toDoubleArray4x(MLArray ma) {
- MLDouble md = (MLDouble) ma;
- int m = md.getM();
- double[] data = new double[m/4];
- for (int i = 0; i < m/4; i++) {
- data[i] = (md.get(4*i, 0)+md.get(4*i+1, 0)+md.get(4*i+2, 0)+md.get(4*i+3, 0))/4;
- }
- return data;
- }
- }
其中CWRUData是我自己封装的一个简单结构:
- public class CWRUData {
- public String name;
- public String err_type;
- public int depth;
- public int pos = 0;
- public int load;
-
- public double[] DE;
- public double[] FE;
- public double[] BA;
-
- public double rpm;
-
- public List<CWRUBlock> blocks = new ArrayList();
-
- public void cut(int size, int num) {
- for (int current = 0; current < DE.length-size ;current += DE.length/num) {
- double[] blockDE = new double[size];
- for (int i = 0; i < size; i++) {
- blockDE[i] = DE[i + current];
- }
- CWRUBlock block = new CWRUBlock(size, this);
- block.DE = blockDE;
- blocks.add(block);
- }
- }
-
- public int type() {
- if (err_type.equals("Ball")) {
- return depth / 7;
- }
- if (err_type.equals("IR")) {
- return 3 + depth / 7;
- }
- if (err_type.equals("OR")) {
- return 6 + depth / 7;
- }
- return 0;
- }
-
- public void print() {
- System.out.println("length:" + DE.length);
- System.out.println("block num:" + blocks.size());
- }
- }
补充一下上面代码里没提到的东西。一是关于toDoubleArray4x():由于故障数据的采样率是12k,而正常数据是48k,为了保证二者频率一样,因此在存入正常数据的时候使用的是四合一平均池化的toDoubleArray4x()。二是type()的作用:是把数据集根据故障类型分为10类,正常的一类,故障的三类根据depth为7/14/21每个又分为三类,1+3*3=10。由于depth为28的数据不全,舍弃。OR错误类型的数据只使用位置为6的,其他位置的数据舍弃。
CWRU数据集提供的数据是一段很长很长的离散采样序列,需要用滑窗切成一段一段的才能处理。然后分析的时候一般只用DE的数据,因为其他的好像不全。切的方法上面已经给出了,用的时候只需要:
CWRUDataManager.dataList.forEach(d->d.cut(512, 400));
其中CWRUBlock也是自己封装的一个数据类型,代码:
- public class CWRUBlock {
-
- public final int size;
- public final CWRUData source;
- public double[] DE;
- public double[] FE;
- public double[] BA;
-
- public CWRUBlock(int size, CWRUData source) {
- this.size = size;
- this.source = source;
- }
- }
全都处理完之后就可以做dl4j的DataSet了。代码:
- public static DataSet genGenericDataSet1() {
- List<CWRUBlock> blocks = new ArrayList();
- for (var data: dataList) {
- if (data.pos != 0 && data.pos != 6) continue;
- if (data.depth == 28) continue;
- data.blocks.forEach(blocks::add);
- }
- Collections.shuffle(blocks);
- INDArray[] input = blocks.stream().map(b->Nd4j.create(b.DE, 1, b.DE.length)).toArray(INDArray[]::new);
- INDArray inputs = Nd4j.vstack(input);
- INDArray[] output = blocks.stream().map(b->genOutputFromType(b.source.type())).toArray(INDArray[]::new);
- INDArray outputs = Nd4j.vstack(output);
- DataSet dataSet = new DataSet(inputs, outputs);
- return dataSet;
- }
这里为什么不用dl4j自带的shuffle,而要使用Collections.shuffle呢?这是因为我也不知道为啥dl4j自带的shuffle会直接崩掉jvm。。真是神奇的框架捏。
网络结构(一个非常简单的多层感知机):
- public static MultiLayerConfiguration CWRUANN() {
- MultiLayerConfiguration builder = new NeuralNetConfiguration.Builder()
- .seed(19260817L)
- .updater(new Sgd(0.01))
- .weightInit(WeightInit.XAVIER)
- .list()
- .layer(new DenseLayer.Builder().nIn(512).nOut(128)
- .activation(Activation.RELU)
- .build())
- .layer(new DenseLayer.Builder().nIn(128).nOut(32)
- .activation(Activation.RELU)
- .build())
- .layer(new DenseLayer.Builder().nIn(32).nOut(10)
- .activation(Activation.RELU)
- .build())
- .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
- .activation(Activation.SOFTMAX)
- .nIn(10).nOut(10).build())
- .build();
- return builder;
- }
开启监视器,以及训练:
- public static void ANN(DataSet train, DataSet test) throws Exception {
-
- MultiLayerNetwork model = new MultiLayerNetwork(NetFactory.CWRUANN());
- model.init();
-
- UIServer server = UIServer.getInstance();
- server.enableRemoteListener();
- StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
- model.setListeners(new StatsListener(remoteUIRouter));
-
- DataSetIterator iterator = getIter(train, 20);
-
- for (int x = 0; x< 10000; x++) {
- if (!iterator.hasNext()) {
- iterator = getIter(train, 20);
- }
- model.fit(iterator);
- if (x % 10 == 0) {
- model.save(new File("模型保存路径" + x + ".zip"), true);
- Evaluation eval = new Evaluation(10);
- INDArray output = model.output(test.getFeatures());
- eval.eval(test.getLabels(), output);
- log.info(eval.stats());
- }
- }
- }
-
- private static DataSetIterator getIter(final DataSet set, final int batchSize) {
- final List<DataSet> list = set.asList();
- Collections.shuffle(list, new Random());
- return new ListDataSetIterator(list,batchSize);
- }
训练结果:
- ========================Evaluation Metrics========================
- # of classes: 10
- Accuracy: 0.9831
- Precision: 0.9833
- Recall: 0.9834
- F1 Score: 0.9833
- Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
-
-
- =========================Confusion Matrix=========================
- 0 1 2 3 4 5 6 7 8 9
- -----------------------------------------
- 157 0 0 0 0 0 0 0 0 0 | 0 = 0
- 0 141 0 4 0 0 0 0 0 0 | 1 = 1
- 1 0 177 0 1 2 0 0 0 0 | 2 = 2
- 0 1 1 164 0 2 0 0 1 0 | 3 = 3
- 0 0 1 0 163 1 1 0 0 0 | 4 = 4
- 0 4 0 1 1 172 0 0 0 0 | 5 = 5
- 0 0 0 0 0 0 145 0 0 0 | 6 = 6
- 0 0 0 0 0 0 0 159 0 1 | 7 = 7
- 0 0 0 0 0 0 0 0 161 0 | 8 = 8
- 0 0 0 0 0 0 0 4 0 133 | 9 = 9
-
- Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
- ==================================================================
可以看出效果是非常不错的,真是简单的数据集捏。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。