
CNN Text Classification on GPU with Deeplearning4j


An earlier post covered training a text-classification model with convolutional networks, but everything ran on CPU, and with any real volume of data a CPU simply can't keep up. The code is in that earlier post; this one records the pitfalls I hit getting it to run on GPU.
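Before any of this runs, the nd4j-native CPU backend in the pom has to be swapped for a CUDA backend (for DL4J builds of this era, an artifact such as nd4j-cuda-8.0-platform). A minimal sanity check, assuming that swap is done, is to print which backend ND4J actually loaded:

import org.nd4j.linalg.factory.Nd4j;

public class BackendCheck {
    public static void main(String[] args) {
        // Prints the active ND4J backend class; expect a CUDA/JCublas backend
        // rather than the CPU backend once the pom swap has taken effect.
        System.out.println(Nd4j.getBackend().getClass().getName());
    }
}

If this still prints the CPU backend, the dependency change didn't take effect and nothing below will touch the GPU.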


GPU environment:

root@image-ubuntu:~# nvidia-smi
Fri Jul 14 01:21:46 2017
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.51                 Driver Version: 375.51                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla M60           Off  | 0000:00:02.0     Off |                  Off |
| N/A   52C    P0    46W / 150W |   3448MiB /  8123MiB |     21%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID  Type  Process name                               Usage      |
|=============================================================================|
|    0     53395    C   java                                          3438MiB |
+-----------------------------------------------------------------------------+



Training progress:



01:17:19.580 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6942802230659779
01:17:24.036 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6964763564002254
01:17:28.767 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6825513419103831
01:17:33.190 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6779352336198492
01:17:37.119 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6845303732693183
01:17:41.313 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6806721112942656
01:17:44.184 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 200 is 0.6814421662117872
01:17:45.113 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6701808819009931
01:17:49.053 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6650278905527942
01:17:53.505 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6565601670736454
01:17:58.273 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6697584503102176
01:18:02.572 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6510347362144552
01:18:06.458 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6611058336565505
01:18:10.226 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6541094549663357
01:18:13.547 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6428940961803716
01:18:17.627 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6382708927005436
01:18:18.026 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 300 is 0.6395620244327883
01:18:22.073 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6311690317350285
01:18:25.798 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6441202013287363
01:18:29.440 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6297861390295019
01:18:33.594 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6392450155730185
01:18:38.271 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6228116943748379
01:18:42.867 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6184209858527969
01:18:46.715 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6133259157463684
01:18:50.564 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6016078467137244
01:18:52.438 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 400 is 0.6253361305693586
01:18:54.943 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6150038531894072
01:18:58.971 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6179796273714999
01:19:03.511 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5815389255352973
01:19:07.499 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6121308310206943
01:19:11.571 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5908273267756271
01:19:16.073 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.6028526854103197
01:19:20.606 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5953763587640233
01:19:24.801 [main] INFO  o.d.parallelism.ParallelWrapper - Averaged score: 0.5793217319992398
01:19:27.736 [ParallelWrapper trainer 0] INFO  o.d.o.l.ScoreIterationListener - Score at iteration 500


First error:

   
   
Exception in thread "main" java.lang.UnsupportedClassVersionError: org/deeplearning4j/parallelism/ParallelWrapper$Builder : Unsupported major.minor version 52.0
	at java.lang.ClassLoader.defineClass1(Native Method)
	at java.lang.ClassLoader.defineClass(ClassLoader.java:800)
	at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142)
	at java.net.URLClassLoader.defineClass(URLClassLoader.java:449)
	at java.net.URLClassLoader.access$100(URLClassLoader.java:71)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:361)
	at java.net.URLClassLoader$1.run(URLClassLoader.java:355)
	at java.security.AccessController.doPrivileged(Native Method)
	at java.net.URLClassLoader.findClass(URLClassLoader.java:354)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:425)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:308)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:358)
	at com.dianping.deeplearning.test.TestWithGPU.main(TestWithGPU.java:115)

The cause is a JDK that is too old: class-file version 52.0 is the Java 8 format, so this build of ParallelWrapper cannot be loaded by an older JVM. The fix:

   
   
Switch the JVM to Java 1.8.
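A trivial way to confirm the runtime after the switch (nothing project-specific assumed here):

public class JavaVersionCheck {
    public static void main(String[] args) {
        // Classes compiled at major version 52 need Java 8 or later,
        // so this should print 1.8.0_x (or higher).
        System.out.println(System.getProperty("java.version"));
    }
}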


Second error:
    
    
INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:22.643 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:28.993 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
03:57:35.097 [main] INFO o.d.parallelism.ParallelWrapper - Averaged score: NaN
This is a numeric-precision problem during GPU training; forcing single precision fixes it:
     
     
DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
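One detail worth calling out: the data-type override only affects buffers allocated after the call, so it has to run before any INDArray or network is created. A minimal sketch of the ordering:

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;

public class PrecisionSetup {
    public static void main(String[] args) {
        // Force single precision BEFORE any INDArray/network exists;
        // buffers created earlier keep whatever data type was active then.
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
        // ... only now build the ComputationGraph, iterators, etc.
    }
}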


Third error:
     
     
Exception in thread "main" java.lang.RuntimeException: Exception thrown in base iterator
	at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator.next(AsyncDataSetIterator.java:247)
	at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator.next(AsyncDataSetIterator.java:33)
	at org.deeplearning4j.parallelism.ParallelWrapper.fit(ParallelWrapper.java:379)
	at com.dianping.deeplearning.cnn.TrainAdxCnnModelWithGPU.main(TrainAdxCnnModelWithGPU.java:170)
Caused by: org.nd4j.linalg.exception.ND4JIllegalStateException: Invalid shape: Requested INDArray shape [1, 0] contains dimension size values < 1 (all dimensions must be 1 or more)
	at org.nd4j.linalg.factory.Nd4j.checkShapeValues(Nd4j.java:4776)
	at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:3997)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.create(BaseNDArray.java:1906)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.subArray(BaseNDArray.java:2064)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4015)
	at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:222)
	at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:155)
	at com.dianping.deeplearning.cnn.CnnSentenceDataSetIterator.next(CnnSentenceDataSetIterator.java:25)
	at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator$IteratorRunnable.run(AsyncDataSetIterator.java:322)

Setting featuresMask to null in the iterator fixes this; a sketch follows.
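The iterator here is a local copy of CnnSentenceDataSetIterator (package com.dianping.deeplearning.cnn), so the exact line differs per copy; the fragment below only illustrates the shape of the fix, and the variable names are assumptions:

// Inside the custom CnnSentenceDataSetIterator.next(int):
// the [1, 0]-shaped subArray came from assembling a per-sentence features mask.
// Passing null instead lets training proceed (the minibatch is then treated
// as unmasked, with all sentences padded to the same length):
INDArray featuresMask = null; // was: a mask built via BaseNDArray.subArray(...)
return new org.nd4j.linalg.dataset.DataSet(features, labels, featuresMask, labelsMask);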





Finally, the full GPU training code:

package com.dianping.deeplearning.cnn;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UnsupportedEncodingException;
import java.util.List;
import java.util.Random;

import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TrainAdxCnnModelWithGPU {

    public static void main(String[] args) throws FileNotFoundException,
            UnsupportedEncodingException {
        /*
         * GPU training setup
         */
        System.out.println("....... GPU initialization starting .......");
        // PLEASE NOTE: for CUDA, FP16 precision support is available, but we
        // force FLOAT here to avoid the NaN scores described above
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
        // temp workaround for backend initialization
        CudaEnvironment.getInstance().getConfiguration()
                // key option enabled
                .allowMultiGPU(true)
                // allow larger memory caches
                .setMaximumDeviceCache(2L * 1024L * 1024L * 1024L)
                // cross-device access is used for faster model averaging over PCIe
                .allowCrossDeviceAccess(true);
        System.out.println("....... GPU initialization finished .......");

        String WORD_VECTORS_PATH = "/home/zhoumeixu/model/word2vec.model";

        // Basic configuration
        int batchSize = 128;
        int vectorSize = 15;                 // word-vector dimensionality; must match the word2vec model
        int nEpochs = 15000;                 // number of training epochs
        int iterator = 1;                    // iterations per minibatch
        int truncateReviewsToLength = 256;   // sentences longer than 256 tokens are truncated
        int cnnLayerFeatureMaps = 100;       // feature maps / channels / depth of each CNN layer
        PoolingType globalPoolingType = PoolingType.MAX;
        Random rng = new Random(100);        // for random sampling

        // Network configuration: parallel convolution layers with filter widths 3, 4 and 5
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.RELU)
                .activation(Activation.LEAKYRELU)
                .updater(Updater.ADAM)
                // ConvolutionMode.Same is important so we can 'stack' the results later
                .convolutionMode(ConvolutionMode.Same)
                .regularization(true)
                .l2(0.0001)
                .iterations(iterator)
                .learningRate(0.01)
                .graphBuilder()
                .addInputs("input")
                .addLayer("cnn3",
                        new ConvolutionLayer.Builder()
                                .kernelSize(3, vectorSize)
                                .stride(1, vectorSize).nIn(1)
                                .nOut(cnnLayerFeatureMaps).build(), "input")
                .addLayer("cnn4",
                        new ConvolutionLayer.Builder()
                                .kernelSize(4, vectorSize)
                                .stride(1, vectorSize).nIn(1)
                                .nOut(cnnLayerFeatureMaps).build(), "input")
                .addLayer("cnn5",
                        new ConvolutionLayer.Builder()
                                .kernelSize(5, vectorSize)
                                .stride(1, vectorSize).nIn(1)
                                .nOut(cnnLayerFeatureMaps).build(), "input")
                // Perform depth concatenation
                .addVertex("merge", new MergeVertex(), "cnn3", "cnn4", "cnn5")
                .addLayer("globalPool",
                        new GlobalPoolingLayer.Builder()
                                .poolingType(globalPoolingType).build(), "merge")
                .addLayer("out",
                        new OutputLayer.Builder()
                                .lossFunction(LossFunctions.LossFunction.MCXENT)
                                .activation(Activation.SOFTMAX)
                                .nIn(3 * cnnLayerFeatureMaps).nOut(2).build(),
                        "globalPool")
                .setOutputs("out").build();

        ComputationGraph net = new ComputationGraph(config);
        net.init();

        // ParallelWrapper will take care of load balancing between GPUs
        ParallelWrapper wrapper = new ParallelWrapper.Builder(net)
                // DataSet prefetching option; set with respect to the number of actual devices
                .prefetchBuffer(24)
                // set the number of workers equal to or higher than the number of available devices; 1x-2x are good starting values
                .workers(4)
                // rarer averaging improves performance, but might reduce model accuracy
                .averagingFrequency(3)
                // if TRUE, the model score is reported on every averaging
                .reportScoreAfterAveraging(true)
                // optional parameter; set to false ONLY if your system supports P2P memory access across PCIe (hint: AWS does not)
                .useLegacyAveraging(true)
                .build();

        net.setListeners(new ScoreIterationListener(100));

        // Load the word vectors and create the train/test DataSetIterators
        System.out.println("Loading word vectors and creating DataSetIterators");
        // Alternatively, for a text-format model:
        // WordVectors wordVectors = WordVectorSerializer
        //         .fromPair(WordVectorSerializer.loadTxt(new File(WORD_VECTORS_PATH)));
        WordVectors wordVectors = WordVectorSerializer
                .readWord2VecModel(WORD_VECTORS_PATH);
        DataSetIterator trainIter = getDataSetIterator(true, wordVectors,
                batchSize, truncateReviewsToLength, rng);
        DataSetIterator testIter = getDataSetIterator(false, wordVectors,
                batchSize, truncateReviewsToLength, rng);

        System.out.println("Starting training");
        for (int i = 0; i < nEpochs; i++) {
            wrapper.fit(trainIter);
            trainIter.reset();
            // Evaluate the network on the test set after each epoch
            Evaluation evaluation = net.evaluate(testIter);
            testIter.reset();
            System.out.println(evaluation.stats());
            System.out.println("....... epoch " + i + " finished .......");
        }

        /*
         * Save the model
         */
        saveNet("/home/zhoumeixu/model/cnn.model", net);

        /*
         * Load the model
         */
        ComputationGraph netload = loadNet("/home/zhoumeixu/model/cnn.model");

        // After training: load a sentence and print the predictions
        String contentsFirstPas = "我的 手机 是 手机号码";
        INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) testIter)
                .loadSingleSentence(contentsFirstPas);
        INDArray predictionsFirstNegative = netload
                .outputSingle(featuresFirstNegative);
        List<String> labels = testIter.getLabels();
        System.out.println("\n\nPredictions for first negative review:");
        for (int i = 0; i < labels.size(); i++) {
            System.out.println("P(" + labels.get(i) + ") = "
                    + predictionsFirstNegative.getDouble(i));
        }
    }

    private static DataSetIterator getDataSetIterator(boolean isTraining,
            WordVectors wordVectors, int minibatchSize, int maxSentenceLength,
            Random rng) {
        String path = isTraining ? "/home/zhoumeixu/model/rnnsenec.txt"
                : "/home/zhoumeixu/model/rnnsenectest.txt";
        // LabeledSentence is this project's custom LabeledSentenceProvider implementation
        LabeledSentenceProvider sentenceProvider = new LabeledSentence(path, rng);
        return new CnnSentenceDataSetIterator.Builder()
                .sentenceProvider(sentenceProvider).wordVectors(wordVectors)
                .minibatchSize(minibatchSize)
                .maxSentenceLength(maxSentenceLength)
                .useNormalizedWordVectors(false).build();
    }

    public static void saveNet(String path, ComputationGraph net) {
        ObjectOutputStream objectOutputStream = null;
        try {
            objectOutputStream = new ObjectOutputStream(new FileOutputStream(path));
            objectOutputStream.writeObject(net);
            objectOutputStream.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static ComputationGraph loadNet(String path) {
        ObjectInputStream objectInputStream = null;
        ComputationGraph net = null;
        try {
            objectInputStream = new ObjectInputStream(new FileInputStream(path));
            net = (ComputationGraph) objectInputStream.readObject();
            objectInputStream.close();
        } catch (Exception e) {
            e.printStackTrace(); // don't swallow load failures silently
        }
        return net;
    }
}
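A side note on persistence: saveNet/loadNet above use plain Java serialization, which ties the saved file to the exact DL4J classes on the classpath. DL4J also ships ModelSerializer, which is the more portable route; a minimal sketch as an alternative (same model, hypothetical separate path):

import java.io.File;
import java.io.IOException;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;

public class ModelPersistence {
    // Alternative to the ObjectOutputStream-based saveNet/loadNet above.
    public static void save(ComputationGraph net, String path) throws IOException {
        ModelSerializer.writeModel(net, new File(path), true); // true = also save updater state
    }

    public static ComputationGraph load(String path) throws IOException {
        return ModelSerializer.restoreComputationGraph(new File(path));
    }
}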


