赞
踩
ROUGE (Recall-Oriented Understudy for Gisting Evaluation) 是用于评估自动文摘或机器翻译的一种评估方法,其中的ROUGE-L指标是基于最长公共子序列(Longest Common Subsequence,LCS)来计算的
我们做AI问答系统,需要一些量化指标来作为优化人工智能大模型的指导标准,经过调查Rouge-L的特征测量是量化指标的手段之一。
为了采用更精确的分词算法、词性还原和停用词处理,我借助一些自然语言处理的库,以下是我引入的maven依赖:
<dependency> <groupId>edu.stanford.nlp</groupId> <artifactId>stanford-corenlp</artifactId> <version>3.9.2</version> </dependency> <dependency> <groupId>edu.stanford.nlp</groupId> <artifactId>stanford-corenlp</artifactId> <version>3.9.2</version> <classifier>models</classifier> </dependency> <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <version>4.12</version> <scope>test</scope> </dependency>
然后是我的具体代码实现,因为词性标注和词性还原需要借助本地模型实现,为了快速落地量化指标,我暂时不使用词性标注和词性还原。
(Recall-Oriented Understudy for Gisting Evaluation) 是用于评估自动文摘或机器翻译的一种评估方法,其中的ROUGE-L指标是基于最长公共子序列(Longest Common Subsequence,LCS)来计算的。它主要关注词序列的匹配程度,不依赖于词性标注和词性还原。
因此,即使不使用词性标注和词性还原,只要确保分词正确,你仍然可以得到有效的ROUGE-L指标。
以下是我的代码具体实现:
package com.xxx.zjtest.testtest.test; import edu.stanford.nlp.ling.CoreAnnotations; import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.pipeline.*; import edu.stanford.nlp.util.CoreMap; import java.util.*; /** * * 不使用词性标注和词性还原直接计算ROUGE-L指标本身不会产生问题。 * */ public class RougeLCalculator { private static final Set<String> STOP_WORDS = new HashSet<>(Arrays.asList("的", "了", "和", "是", "就", "都", "而", "及", "与", "在")); // 中文停用词示例列表 public static double calculateRougeL(String referenceText, String hypothesisText) { // 创建Stanford CoreNLP管道 Properties props = new Properties(); props.setProperty("annotators", "tokenize, ssplit"); //加, pos, lemma 会报错 props.setProperty("ssplit.eolonly", "true"); StanfordCoreNLP pipeline = new StanfordCoreNLP(props); // 处理参考答案和假设答案 Annotation referenceAnnotation = new Annotation(referenceText); Annotation hypothesisAnnotation = new Annotation(hypothesisText); pipeline.annotate(referenceAnnotation); pipeline.annotate(hypothesisAnnotation); // 获取参考答案和假设答案的词性还原后的单词列表 /*List<String> referenceWords = getLemmatizedWords(referenceAnnotation, STOP_WORDS); List<String> hypothesisWords = getLemmatizedWords(hypothesisAnnotation, STOP_WORDS);*/ // 获取参考答案和假设答案的分词列表 List<String> referenceWords = getTokenizedWords(referenceAnnotation, STOP_WORDS); List<String> hypothesisWords = getTokenizedWords(hypothesisAnnotation, STOP_WORDS); // 计算ROUGE-L指标 return calculateLongestCommonSubsequence(referenceWords, hypothesisWords); } //词性标注和词性还原部分,此部分需要使用大模型,无法实现 /*private static List<String> getLemmatizedWords(Annotation annotation, Set<String> stopWords) { List<String> words = new ArrayList<>(); for (CoreMap sentence : annotation.get(CoreAnnotations.SentencesAnnotation.class)) { for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) { String word = token.word(); // 获取分词后的单词 String pos = token.get(CoreAnnotations.PartOfSpeechAnnotation.class); // 获取词性标注 String lemma = token.get(CoreAnnotations.LemmaAnnotation.class); // 获取词性还原后的单词 if (!stopWords.contains(word) && !pos.equalsIgnoreCase("PUNCT")) { // 过滤停用词和标点符号 words.add(lemma); } } } return words; }*/ private static List<String> getTokenizedWords(Annotation annotation, Set<String> stopWords) { List<String> words = new ArrayList<>(); for (CoreMap sentence : annotation.get(CoreAnnotations.SentencesAnnotation.class)) { for (CoreLabel token : sentence.get(CoreAnnotations.TokensAnnotation.class)) { String word = token.word(); // 获取分词后的单词 if (!stopWords.contains(word)) { // 过滤停用词 words.add(word); } } } return words; } //计算公共最长子序列 private static double calculateLongestCommonSubsequence(List<String> referenceWords, List<String> hypothesisWords) { int[][] dp = new int[referenceWords.size() + 1][hypothesisWords.size() + 1]; for (int i = 1; i <= referenceWords.size(); i++) { for (int j = 1; j <= hypothesisWords.size(); j++) { if (referenceWords.get(i - 1).equals(hypothesisWords.get(j - 1))) { dp[i][j] = dp[i - 1][j - 1] + 1; } else { dp[i][j] = Math.max(dp[i - 1][j], dp[i][j - 1]); } } } int lcs = dp[referenceWords.size()][hypothesisWords.size()]; double precision = (double) lcs / hypothesisWords.size(); double recall = (double) lcs / referenceWords.size(); double rougeL = 2 * precision * recall / (precision + recall); System.out.println(rougeL); return rougeL; } public static double calculateAverageRougeL(List<String> referenceTexts, List<String> hypothesisTexts) { if (referenceTexts.size() != hypothesisTexts.size()) { throw new IllegalArgumentException("参考文本列表和假设文本列表的长度必须相等"); } double totalRougeL = 0; for (int i = 0; i < referenceTexts.size(); i++) { totalRougeL += calculateRougeL(referenceTexts.get(i), hypothesisTexts.get(i)); } return totalRougeL / referenceTexts.size(); } public static void main(String[] args) { String yTest1 = "客户专用网络的使用应注意以下几点:\n" + "\n" + "实施方案应由公司网络管理员组织制定,并与客户共同协商制定实施方案。\n" + "禁止将客户要求隔离的网络与未经客户许可的网络连通。\n" + "禁止将客户要求使用的网络与公司内部网络连通。\n" + "禁止在客户专用网络中搭建无线网络。\n" + "禁止在客户专用网络中使用笔记本电脑。\n" + "禁止在客户专用网络中未经过客户允许进行互联网访问。\n" + "禁止在客户专用网络中未按照客户要求执行设备要求。"; String nTest1 = "根据已知信息,客户专用网络的注意事项有以下几点:\n" + "\n" + "禁止将客户要求隔离的网络与未经客户许可的网络连通。\n" + "禁止将客户要求使用的网络与公司内部网络连通。\n" + "禁止在客户专用网络中搭建无线网络。\n" + "禁止在客户专用网络中使用笔记本电脑。\n" + "禁止在客户专用网络中未经过客户允许进行互联网访问。\n" + "禁止在客户专用网络中未按照客户要求执行设备要求。\n" + "客户专用网络的申请与构建应由公司网络管理员组织制定,并与客户共同协商制定实施方案。\n" + "客户专用网络的规划与实施应由公司网络管理员与申请部门共同执行,如在实施过程中需要第三方参与,公司网络管理员应进行监督,并与申请部门共同验收。\n" + "若需使用客户的内部网络,必须事先获得客户批准,并严格遵守客户的网络安全规定,只能从事和业务相关的工作,不得私自查看其它计算机或客户内部网络上的任何保密信息,禁止任何令客户网络严重增加负载或容易引起网络故障的行为。\n" + "若需在客户现场访问Internet,必须事先获得客户批准,并按照客户的要求采取必要的安全措施。"; String yTest2 = "禁止将公司资料、私人信息以及敏感内容存储在个人邮箱中,避免信息泄露。\n" + "及时删除接收的邮件,避免信息被他人获取。\n" + "禁止使用私人邮箱用于工作中的业务往来,避免信息泄露。\n" + "谨慎使用电子邮件,避免感染病毒,确保邮件发送保密信息。\n" + "禁止使用工作手机号码进行非工作目的的网站注册。\n" + "及时将工作手机的操作系统升级至最新版本。\n" + "工作手机信息安全管理要求,包括保密性、完整性、可用性等方面。"; String nTest2="不要将公司资料、私人信息以及敏感内容存储在个人邮箱中,避免信息泄露。\n" + "及时删除接收的邮件,避免信息被他人获取。\n" + "禁止使用私人邮箱用于工作中的业务往来。\n" + "谨慎使用电子邮件,避免感染病毒,确保邮件发送保密信息。\n" + "禁止使用工作手机号码进行非工作目的的网站注册。\n" + "及时将工作手机的操作系统升级至最新版本。\n" + "遵守公司电子邮件尺寸规定,不得发送超大邮件。\n" + "收到不明电子邮件尽量不要回信,含有可疑附件时不得打开,并应立即删除该邮件。\n" + "发送电子邮件时,应认真核对收件人地址,避免误传送。\n" + "若发生电子邮件误传送时,应立即再发一封致歉信,声明发错,并请对方将已收到的邮件删除。\n" + "若邮件含有公司敏感信息,应立即向部门领导及业务安全部汇报。\n" + "当收到别人发错的电子邮件时,应立即通知提醒发件人,若该邮件含有改善敏感信息,应将邮件彻底删除。\n" + "禁止使用手机VPN进行加密。\n" + "禁止使用工作手机接收邮件。\n" + "禁止使用私人邮箱进行邮件收发。\n" + "禁止使用手机接收或发送邮件时点击或打开可疑链接或附件,避免被恶意攻击导致信息泄露。"; //计算单个值 //System.out.println(calculateRougeL(yTest1,nTest1)); // 计算平均值 List<String> referenceTexts = Arrays.asList(yTest1, yTest2); List<String> hypothesisTexts = Arrays.asList(nTest1, nTest2); double averageRougeL = calculateAverageRougeL(referenceTexts, hypothesisTexts); System.out.println("平均ROUGE-L值: " + averageRougeL); } }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。