当前位置:   article > 正文

Hanlp之文本分类_com.hankcs.hanlp.utility.testutility

com.hankcs.hanlp.utility.testutility

1、语料库格式

  1. 分类语料的根目录。目录必须满足如下结构:
  2. 根目录
  3. ├── 分类A
  4. │ ├── 1.txt
  5. │ ├── 2.txt
  6. │ └── 3.txt
  7. ├── 分类B
  8. │ ├── 1.txt
  9. │ └── ...
  10. └── ...
  11. 文件不一定需要用数字命名,也不需要以txt作为后缀名,但一定需要是文本文件。

2、项目格式

训练用的分类语料库要放到项目的 data/test/ 目录下(本文示例为 data/test/搜狗文本分类语料库迷你版)。

3、代码

(1)TestUtility类

  1. /*
  2. * <author>Han He</author>
  3. * <email>me@hankcs.com</email>
  4. * <create-date>2018-06-23 11:05 PM</create-date>
  5. *
  6. * <copyright file="TestUtility.java">
  7. * Copyright (c) 2018, Han He. All Rights Reserved, http://www.hankcs.com/
  8. * This source is subject to Han He. Please contact Han He for more information.
  9. * </copyright>
  10. */
  11. package com.cn.test.TextClassification;
  12. import com.hankcs.hanlp.HanLP;
  13. import java.io.*;
  14. import java.net.HttpURLConnection;
  15. import java.net.URL;
  16. import java.util.zip.ZipEntry;
  17. import java.util.zip.ZipInputStream;
  18. /**
  19. * @author hankcs
  20. */
  21. public class TestUtility
  22. {
  23. static
  24. {
  25. ensureFullData();
  26. }
  27. public static void ensureFullData()
  28. {
  29. ensureData(HanLP.Config.PerceptronCWSModelPath, "http://nlp.hankcs.com/download.php?file=data", HanLP.Config.PerceptronCWSModelPath.split("data")[0], false);
  30. }
  31. /**
  32. * 保证 name 存在,不存在时自动下载解压
  33. *
  34. * @param name 路径
  35. * @param url 下载地址
  36. * @return name的绝对路径
  37. */
  38. public static String ensureData(String name, String url)
  39. {
  40. return ensureData(name, url, null, true);
  41. }
  42. /**
  43. * 保证 name 存在,不存在时自动下载解压
  44. *
  45. * @param name 路径
  46. * @param url 下载地址
  47. * @return name的绝对路径
  48. */
  49. public static String ensureData(String name, String url, String parentPath, boolean overwrite)
  50. {
  51. File target = new File(name);
  52. if (target.exists()) return target.getAbsolutePath();
  53. try
  54. {
  55. File parentFile = parentPath == null ? new File(name).getParentFile() : new File(parentPath);
  56. if (!parentFile.exists()) parentFile.mkdirs();
  57. String filePath = downloadFile(url, parentFile.getAbsolutePath());
  58. if (filePath.endsWith(".zip"))
  59. {
  60. unzip(filePath, parentFile.getAbsolutePath(), overwrite);
  61. }
  62. return target.getAbsolutePath();
  63. }
  64. catch (Exception e)
  65. {
  66. System.err.printf("数据下载失败,请尝试手动下载 %s 到 %s 。原因如下:\n", url, target.getAbsolutePath());
  67. e.printStackTrace();
  68. System.exit(1);
  69. return null;
  70. }
  71. }
  72. /**
  73. * 保证 data/test/name 存在
  74. *
  75. * @param name
  76. * @param url
  77. * @return
  78. */
  79. public static String ensureTestData(String name, String url)
  80. {
  81. return ensureData(String.format("data/test/%s", name), url);
  82. }
  83. /**
  84. * Downloads a file from a URL
  85. *
  86. * @param fileURL HTTP URL of the file to be downloaded
  87. * @param savePath path of the directory to save the file
  88. * @throws IOException
  89. * @author www.codejava.net
  90. */
  91. public static String downloadFile(String fileURL, String savePath)
  92. throws IOException
  93. {
  94. System.err.printf("Downloading %s to %s\n", fileURL, savePath);
  95. HttpURLConnection httpConn = request(fileURL);
  96. while (httpConn.getResponseCode() == HttpURLConnection.HTTP_MOVED_PERM || httpConn.getResponseCode() == HttpURLConnection.HTTP_MOVED_TEMP)
  97. {
  98. httpConn = request(httpConn.getHeaderField("Location"));
  99. }
  100. // always check HTTP response code first
  101. if (httpConn.getResponseCode() == HttpURLConnection.HTTP_OK)
  102. {
  103. String fileName = "";
  104. String disposition = httpConn.getHeaderField("Content-Disposition");
  105. String contentType = httpConn.getContentType();
  106. int contentLength = httpConn.getContentLength();
  107. if (disposition != null)
  108. {
  109. // extracts file name from header field
  110. int index = disposition.indexOf("filename=");
  111. if (index > 0)
  112. {
  113. fileName = disposition.substring(index + 10,
  114. disposition.length() - 1);
  115. }
  116. }
  117. else
  118. {
  119. // extracts file name from URL
  120. fileName = new File(httpConn.getURL().getPath()).getName();
  121. }
  122. // System.out.println("Content-Type = " + contentType);
  123. // System.out.println("Content-Disposition = " + disposition);
  124. // System.out.println("Content-Length = " + contentLength);
  125. // System.out.println("fileName = " + fileName);
  126. // opens input stream from the HTTP connection
  127. InputStream inputStream = httpConn.getInputStream();
  128. String saveFilePath = savePath;
  129. if (new File(savePath).isDirectory())
  130. saveFilePath = savePath + File.separator + fileName;
  131. String realPath;
  132. if (new File(saveFilePath).isFile())
  133. {
  134. System.err.printf("Use cached %s instead.\n", fileName);
  135. realPath = saveFilePath;
  136. }
  137. else
  138. {
  139. saveFilePath += ".downloading";
  140. // opens an output stream to save into file
  141. FileOutputStream outputStream = new FileOutputStream(saveFilePath);
  142. int bytesRead;
  143. byte[] buffer = new byte[4096];
  144. long start = System.currentTimeMillis();
  145. int progress_size = 0;
  146. while ((bytesRead = inputStream.read(buffer)) != -1)
  147. {
  148. outputStream.write(buffer, 0, bytesRead);
  149. long duration = (System.currentTimeMillis() - start) / 1000;
  150. duration = Math.max(duration, 1);
  151. progress_size += bytesRead;
  152. int speed = (int) (progress_size / (1024 * duration));
  153. float ratio = progress_size / (float) contentLength;
  154. float percent = ratio * 100;
  155. int eta = (int) (duration / ratio * (1 - ratio));
  156. int minutes = eta / 60;
  157. int seconds = eta % 60;
  158. System.err.printf("\r%.2f%%, %d MB, %d KB/s, ETA %d min %d s", percent, progress_size / (1024 * 1024), speed, minutes, seconds);
  159. }
  160. System.err.println();
  161. outputStream.close();
  162. realPath = saveFilePath.substring(0, saveFilePath.length() - ".downloading".length());
  163. if (!new File(saveFilePath).renameTo(new File(realPath)))
  164. throw new IOException("Failed to move file");
  165. }
  166. inputStream.close();
  167. httpConn.disconnect();
  168. return realPath;
  169. }
  170. else
  171. {
  172. httpConn.disconnect();
  173. throw new IOException("No file to download. Server replied HTTP code: " + httpConn.getResponseCode());
  174. }
  175. }
  176. private static HttpURLConnection request(String url) throws IOException
  177. {
  178. HttpURLConnection httpConn = (HttpURLConnection) new URL(url).openConnection();
  179. httpConn.setRequestProperty("User-Agent", "Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.4; en-US; rv:1.9.2.2) Gecko/20100316 Firefox/3.6.2");
  180. return httpConn;
  181. }
  182. private static void unzip(String zipFilePath, String destDir, boolean overwrite)
  183. {
  184. System.err.println("Unzipping to " + destDir);
  185. File dir = new File(destDir);
  186. // create output directory if it doesn't exist
  187. if (!dir.exists()) dir.mkdirs();
  188. FileInputStream fis;
  189. //buffer for read and write data to file
  190. byte[] buffer = new byte[4096];
  191. try
  192. {
  193. fis = new FileInputStream(zipFilePath);
  194. ZipInputStream zis = new ZipInputStream(fis);
  195. ZipEntry ze = zis.getNextEntry();
  196. while (ze != null)
  197. {
  198. String fileName = ze.getName();
  199. File newFile = new File(destDir + File.separator + fileName);
  200. if (overwrite || !newFile.exists())
  201. {
  202. if (ze.isDirectory())
  203. {
  204. //create directories for sub directories in zip
  205. newFile.mkdirs();
  206. }
  207. else
  208. {
  209. new File(newFile.getParent()).mkdirs();
  210. FileOutputStream fos = new FileOutputStream(newFile);
  211. int len;
  212. while ((len = zis.read(buffer)) > 0)
  213. {
  214. fos.write(buffer, 0, len);
  215. }
  216. fos.close();
  217. //close this ZipEntry
  218. zis.closeEntry();
  219. }
  220. }
  221. ze = zis.getNextEntry();
  222. }
  223. //close last ZipEntry
  224. zis.closeEntry();
  225. zis.close();
  226. fis.close();
  227. new File(zipFilePath).delete();
  228. }
  229. catch (IOException e)
  230. {
  231. e.printStackTrace();
  232. }
  233. }
  234. }

(2)ModelTrain类

  1. package com.cn.test.TextClassification;
  2. import com.hankcs.hanlp.classification.classifiers.IClassifier;
  3. import com.hankcs.hanlp.classification.classifiers.NaiveBayesClassifier;
  4. import com.hankcs.hanlp.classification.models.NaiveBayesModel;
  5. import com.hankcs.hanlp.corpus.io.IOUtil;
  6. import java.io.File;
  7. public class ModelTrain {
  8. /**
  9. * 搜狗文本分类语料库5个类目,每个类目下1000篇文章,共计5000篇文章
  10. */
  11. public static final String CORPUS_FOLDER = TestUtility.ensureTestData("搜狗文本分类语料库迷你版", "");
  12. /**
  13. * 模型保存路径
  14. */
  15. public static final String MODEL_PATH = "data/test/classification-model.ser";
  16. public static NaiveBayesModel trainOrLoadModel()
  17. {
  18. NaiveBayesModel model = (NaiveBayesModel) IOUtil.readObjectFrom(MODEL_PATH);
  19. if (model != null) return model;
  20. File corpusFolder = new File(CORPUS_FOLDER);
  21. if (!corpusFolder.exists() || !corpusFolder.isDirectory())
  22. {
  23. System.err.println("没有文本分类语料!");
  24. System.exit(1);
  25. }
  26. try{
  27. IClassifier classifier = new NaiveBayesClassifier(); // 创建分类器,更高级的功能请参考IClassifier的接口定义
  28. classifier.train(CORPUS_FOLDER); // 训练后的模型支持持久化,下次就不必训练了
  29. model = (NaiveBayesModel) classifier.getModel();
  30. IOUtil.saveObjectTo(model, MODEL_PATH);
  31. }catch (Exception e){
  32. e.printStackTrace();
  33. }
  34. return model;
  35. }
  36. }

(3)InitOneObject类

  1. package com.cn.test.TextClassification;
  2. import com.hankcs.hanlp.classification.classifiers.IClassifier;
  3. import com.hankcs.hanlp.classification.classifiers.NaiveBayesClassifier;
  4. public class InitOneObject {
  5. public static final InitOneObject instance= new InitOneObject();
  6. //获取classifier对象,训练后的模型支持持久化,下次就不必训练了.
  7. public IClassifier classifier = new NaiveBayesClassifier(ModelTrain.trainOrLoadModel());
  8. }

(4)TrainMain类

  1. package com.cn.test.TextClassification;
  2. import com.hankcs.hanlp.classification.classifiers.IClassifier;
  3. import java.io.IOException;
  4. public class TrainMain {
  5. public static void main(String[] args) throws IOException
  6. {
  7. //IClassifier classifier = new NaiveBayesClassifier(trainOrLoadModel());
  8. IClassifier classifier= InitOneObject.instance.classifier;
  9. predict(classifier, "C罗获2018环球足球奖最佳球员 德尚荣膺最佳教练");
  10. predict(classifier, "英国造航母耗时8年仍未服役 被中国速度远远甩在身后");
  11. predict(classifier, "研究生考录模式亟待进一步专业化");
  12. predict(classifier, "如果真想用食物解压,建议可以食用燕麦");
  13. predict(classifier, "锄禾日当午,汗滴禾下土");
  14. }
  15. private static void predict(IClassifier classifier, String text)
  16. {
  17. System.out.printf("《%s》 属于分类 【%s】\n", text, classifier.classify(text));
  18. }
  19. }

4、运行结果

  1. 模式:训练集
  2. 文本编码:UTF-8
  3. 根目录:C:\MyselfApplication\MyProject\HanLp\data\test\搜狗文本分类语料库迷你版
  4. 加载中...
  5. [体育]...100.00% 1000 篇文档
  6. [健康]...100.00% 1000 篇文档
  7. [军事]...100.00% 1000 篇文档
  8. [教育]...100.00% 1000 篇文档
  9. [汽车]...100.00% 1000 篇文档
  10. 耗时 15868 ms 加载了 5 个类目,共 5000 篇文档
  11. 原始数据集大小:5000
  12. 使用卡方检测选择特征中...耗时 189 ms,选中特征数:18156 / 80986 = 22.42%
  13. 贝叶斯统计结束
  14. 《C罗获2018环球足球奖最佳球员 德尚荣膺最佳教练》 属于分类 【体育】
  15. 《英国造航母耗时8年仍未服役 被中国速度远远甩在身后》 属于分类 【军事】
  16. 《研究生考录模式亟待进一步专业化》 属于分类 【教育】
  17. 《如果真想用食物解压,建议可以食用燕麦》 属于分类 【健康】
  18. 《锄禾日当午,汗滴禾下土》 属于分类 【健康】

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/953970
推荐阅读
相关标签
  

闽ICP备14008679号