当前位置:   article > 正文

Elasticsearch 混合检索优化大模型 RAG 任务_基于倒数的融合排序

基于倒数的融合排序

Elastic 社区在自然语言处理上面做的很不错官方博客更新速度也很快,现阶段大模型的应用场景主要在 Rag 和 Agent 上,国内 Rag(Retrieval-Augmented Generation 检索增强生成) 的尤其多,而搜索对于 Elasticsearch 来说是强项特别是 8.9 之后的版本提供了 ESRE 模块(集成了高级相关性排序如 BM25f、强大的矢量数据库、自然语言处理技术、与第三方模型如 GPT-3 和 GPT-4 的集成,并支持开发者自定义模型与应用),经过我的各种尝试在 Elasticsearch 上做 NLP 是一个很不错的选择,要做大规模的 RAG 任务甚至是针对图像、声音、多模态、关键词等大数据量的向量召回且搭配生成式模型这种复杂的业务场景 Elasticsearch 是天生支持的。此篇文章主要记录混合检索(BM25 +HNSW)倒数融合排序(RRF)完整测试。

官博有几篇不错的文章可以看看:

先说一下 RAG 任务的流程,以民法典为例 LLM 可以在现有资料上分析出确切的回答:
文档分割 -> 文本向量化 -> 问句向量化 -> 向量相似 top k个 -> 拼接 prompt 上下文  -> 提交给 LLM 生成回答。

1.混合检索

全文检索 + ANN 检索。因为全文检索能查找更加准确的文档,直观都会感觉比单一的相似度检索更强。一个混合检索的查询语句例如:

  1. {
  2.   "query": {
  3.     "bool": {
  4.       "must": [
  5.         { "match": {"content": {"query": "结婚领证登记需要双发到场吗?","boost": 1}}}
  6.       ]
  7.     }
  8.   },
  9.   "knn": {
  10.     "field": "content_embed",
  11.     "k": 5,
  12.     "num_candidates": 100,
  13.     "query_vector": []   // 向量、省略
  14.   },
  15.   "size": 5
  16. }

2.倒数融合排序

倒数排序融合 - Reciprocal rank fusion:
由于全文搜索及向量搜索是使用不同的算法进行打分的,这就造成把两个不同搜索结果综合起来统一排名的困难。向量搜索的分数处于 0-1.0 之间,而全文搜索的结果排名分数可能是高于10或者更大的值。我们需要一种方法把两种搜索方法的结果进行综合处理,并得出一个唯一的排名。
倒数排序融合(RRF)是一种将具有不同相关性指标的多个结果集组合成单个结果集的方法。 
RRF 无需调优,不同的相关性指标也不必相互关联即可获得高质量的结果。该方法的优势在于不利用相关分数,而仅靠排名计算。相关分数存在的问题在于不同模型的分数范围差。
针对不同的 RAG 任务有不同的处理方式比如 法律、历史、人文类型的任务还可以加入命名实体识别 。或者使用其他语义转换模型将长文本总结为短文本。将拆分的长文本先调用 embed 转为向量后存储到 index 上。然后执行混合检索。

3.Embedding

第一步是文本向量化,这一步可以放在客户端做也可以放在 Elasticsearch 服务端做,不过模型推理是 Elasticsearch 新版中的重大功能,下面演示如何做。

在抱脸上直接搜索  sentence-similarity 模型,最靠前的就是 bge 由智源开源,基本上从去年开始一直是榜一,输入 zh 筛选中文:

使用 langchain 测试推理,模型输出是  dim=1024:

ElasticSearch支持最大 2048,目前 Es 还不支持非固定长度的向量,Elasticsearch 提供了 Eland 工具用于 pytorch 模型的推理和上传,源码安装该工具:

git clone https://github.com/elastic/eland
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
python setup.py install

然后执行上传脚本:
eland_import_hub_model --url http://192.168.197.128:9200 --hub-model-id .\Langchain-Chatchat-0.2.10\model\bge-large-zh-v1.5 --task-type text_embedding --start --clear-previous

上传过程不太顺利发现源码有一些问题需要修改,大致两处:
eland_import_hub_model.py  =>    上传前会把模型和一些文件放到临时目录,因为我的 windwos user name 是中文会找不到路径。直接将 tmp 写死即可。
            with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = 'C:\\tmp' 
transformers.py    =>        函数里面将 token 这个参数去掉
            # model = AutoModel.from_pretrained(model_id, token=token, torchscript=True)
            model = AutoModel.from_pretrained(model_id, torchscript=True)

等待执行完成:

上传成功后在 kibana 模型管理位置点击 Synchronize your jobs and trained models.,同步一下刚刚上传的模型看到,调用推理接口,复制模型id,可以看到模型输出和前面 embed_demo.py 中测试的一样:
POST _ml/trained_models/m_workspace__langchain-chatchat-0.2.10__model__bge-large-zh-v1.5/_infer
{
  "docs": [
    {"text_field": "你好,请问你在干什么?"}
  ]
}

4.文本分割

向量 dim=1024 是无法将一个超长文本完整的语义全部嵌入的,且大模型 token 的限制需要将文档进行分割,最简单的做法是指定 chunk_size(单个文档token数) 和 chunk_overlap(向量文档重叠token数)对文档进行分割,也有按句分割的做法,更加准确的是使用现成的语义分割模型,可以看看 github 上 Langchain-Chatchat 这个项目,提供了多种分割方式:

5.部署 LLM 做增强生成

对于 RAG 任务,更大参数量的 LLM 对效果并没有显著提升, 即使是最小参数量的大模型也涵盖了基本的理解能力,这里部署清华 ChatGLM-6b  int4 量化模型 6G显存就够,这样可以将 token 开到很大。

git clone https://github.com/THUDM/ChatGLM-6B
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
git clone https://huggingface.co/THUDM/chatglm-6b-int4

模型 README.md 中有测试代码,替换一下模型路径就可以了:

然后写一个 ELasticsearch Query 例子,根据搜索文档拼接 Prompt 做问答,Java 完整代码:

  1. package tool.elk;
  2. import com.alibaba.fastjson.JSON;
  3. import com.alibaba.fastjson.JSONArray;
  4. import com.alibaba.fastjson.JSONObject;
  5. import org.apache.http.HttpEntity;
  6. import org.apache.http.HttpHost;
  7. import org.apache.http.client.methods.CloseableHttpResponse;
  8. import org.apache.http.client.methods.HttpPost;
  9. import org.apache.http.entity.ContentType;
  10. import org.apache.http.entity.StringEntity;
  11. import org.apache.http.impl.client.HttpClients;
  12. import org.apache.http.nio.entity.NStringEntity;
  13. import org.apache.http.util.EntityUtils;
  14. import org.elasticsearch.client.*;
  15. import org.elasticsearch.client.indices.CreateIndexRequest;
  16. import org.elasticsearch.client.indices.GetIndexRequest;
  17. import org.elasticsearch.common.xcontent.XContentType;
  18. import java.io.BufferedReader;
  19. import java.io.FileReader;
  20. import java.nio.charset.StandardCharsets;
  21. import java.util.ArrayList;
  22. import java.util.List;
  23. /**
  24. * @desc : elatcisearch rag 测试
  25. * @auth : tyf
  26. * @date : 2024-04-16 10:06:24
  27. */
  28. public class RAGDemo {
  29. public static String es_host = "192.168.197.128";
  30. public static Integer es_port = 9200;
  31. public static String llm_host = "http://0.0.0.0:8000";
  32. public static RestHighLevelClient highLevelClient;
  33. public static RestClient lowLevelClient;
  34. static {
  35. String[] ipArr = es_host.split(",");
  36. HttpHost[] httpHosts = new HttpHost[ipArr.length];
  37. for (int i = 0; i < ipArr.length; i++) {
  38. httpHosts[i] = new HttpHost(ipArr[i], es_port, "http");
  39. }
  40. RestClientBuilder builder = RestClient.builder(httpHosts);
  41. highLevelClient = new RestHighLevelClient(builder);
  42. lowLevelClient = highLevelClient.getLowLevelClient();
  43. System.out.println("初始化成功");
  44. }
  45. // 索引名称
  46. public static String indexName = "doc_split";
  47. // 索引 mapping
  48. public static String indexMapping =
  49. "{\n" +
  50. " \"settings\": {\n" +
  51. " \"number_of_shards\": 1,\n" +
  52. " \"number_of_replicas\": 0\n" +
  53. " },\n" +
  54. " \"mappings\": {\n" +
  55. " \"properties\": {\n" +
  56. " \"content\": {\n" +
  57. " \"type\": \"text\"\n" +
  58. " },\n" +
  59. " \"timestamp\": {\n" +
  60. " \"type\": \"long\"\n" +
  61. " },\n" +
  62. " \"content_embed\": {\n" +
  63. " \"type\": \"dense_vector\",\n" +
  64. " \"dims\": 1024,\n" +
  65. " \"index\": true,\n" +
  66. " \"similarity\": \"cosine\"\n" +
  67. " }\n" +
  68. " }\n" +
  69. " }\n" +
  70. "}";
  71. // embed 模型编号
  72. public static String modelId = "m_workspace__langchain-chatchat-0.2.10__model__bge-large-zh-v1.5";
  73. // 文档召回 _score 阈值
  74. public static double scoreThreshold = 3d;
  75. // 本地文档路径
  76. public static String docPath = "C:\\Users\\唐于凡\\Desktop\\中华人民共和国民法典.txt";
  77. // 创建索引
  78. public static void createIndex() throws Exception{
  79. // System.out.println(indexMapping);
  80. // 索引不存在则创建
  81. GetIndexRequest request1 = new GetIndexRequest(indexName);
  82. boolean response1 = highLevelClient.indices().exists(request1, RequestOptions.DEFAULT);
  83. if(!response1){
  84. CreateIndexRequest request2 = new CreateIndexRequest(indexName);
  85. request2.source(indexMapping, XContentType.JSON);
  86. highLevelClient.indices().create(request2, RequestOptions.DEFAULT);
  87. }
  88. }
  89. // 读取并拆分文档
  90. public static List<String> parseDoc(int chunkSize,int chunkOverlap) throws Exception{
  91. List<String> splitTexts = new ArrayList<>();
  92. try (BufferedReader br = new BufferedReader(new FileReader(docPath))) {
  93. StringBuilder sb = new StringBuilder();
  94. String line;
  95. while ((line = br.readLine()) != null) {
  96. // 去掉没用的空格
  97. line = line.trim();
  98. if (!line.isEmpty()) {
  99. sb.append(line).append(" "); // 可以根据需要调整分隔符
  100. }
  101. }
  102. String fullText = sb.toString().trim();
  103. // 拆分文本
  104. for (int i = 0; i < fullText.length(); i += chunkSize - chunkOverlap) {
  105. if (i + chunkSize < fullText.length()) {
  106. splitTexts.add(fullText.substring(i, i + chunkSize));
  107. } else {
  108. splitTexts.add(fullText.substring(i));
  109. }
  110. }
  111. }
  112. System.out.println("文档总数:"+splitTexts.size());
  113. return splitTexts;
  114. }
  115. // 调用 embed 模型转为向量
  116. public static Object embedDoc(String text){
  117. Object rt = null;
  118. // POST
  119. try {
  120. String entity = "{ \"docs\": [{\"text_field\": \""+text+"\"}]}";
  121. Request req = new Request("POST","_ml/trained_models/"+modelId+"/_infer");
  122. HttpEntity params = new NStringEntity(entity, ContentType.APPLICATION_JSON);
  123. req.setEntity(params);
  124. Response rsp = lowLevelClient.performRequest(req);
  125. HttpEntity en = rsp.getEntity();
  126. String body = EntityUtils.toString(en);
  127. JSONObject data = JSON.parseObject(body);
  128. rt = data.getJSONArray("inference_results").getJSONObject(0).getJSONArray("predicted_value");
  129. }
  130. catch (Exception e){
  131. e.printStackTrace();
  132. }
  133. return rt;
  134. }
  135. // 提交 Elasticsearch
  136. public static void uploadDoc(List<String> docSplits) throws Exception{
  137. // 遍历每个文档
  138. for (int i = 0; i < docSplits.size(); i++) {
  139. // 原始文本
  140. String content = docSplits.get(i);
  141. // 转为向量
  142. Object content_embed = embedDoc(content);
  143. // 时间
  144. Long timestamp = System.currentTimeMillis();
  145. // 上传
  146. JSONObject data = new JSONObject();
  147. data.put("content",content);
  148. data.put("content_embed",content_embed);
  149. data.put("timestamp",timestamp);
  150. Request req = new Request("POST","/"+indexName+"/_doc");
  151. HttpEntity params = new NStringEntity(data.toJSONString(), ContentType.APPLICATION_JSON);
  152. req.setEntity(params);
  153. Response res = lowLevelClient.performRequest(req);
  154. System.out.println("上传第"+i+"条:"+res);
  155. }
  156. }
  157. // 执行混合检索
  158. public static List<String> search(String q) throws Exception{
  159. // 转为向量
  160. Object vector = embedDoc(q);
  161. // 查询语句
  162. String query =
  163. "{\n" +
  164. " \"query\": {\n" +
  165. " \"bool\": {\n" +
  166. " \"must\": [\n" +
  167. " {\n" +
  168. " \"match\": {\n" +
  169. " \"content\": {\n" +
  170. " \"query\": \""+q+"\",\n" +
  171. " \"boost\": 1\n" +
  172. " }\n" +
  173. " }\n" +
  174. " }\n" +
  175. " ]\n" +
  176. " }\n" +
  177. " },\n" +
  178. " \"knn\": {\n" +
  179. " \"field\": \"content_embed\",\n" +
  180. " \"k\": 5,\n" +
  181. " \"num_candidates\": 100,\n" +
  182. " \"query_vector\": "+vector+"\n" +
  183. " },\n" +
  184. " \"size\": 5\n" +
  185. "}\n";
  186. // System.out.println("查询语句:");
  187. // System.out.println(query);
  188. // 调用查询
  189. Request req = new Request("POST","/"+indexName+"/_search?pretty");
  190. HttpEntity params = new NStringEntity(query, ContentType.APPLICATION_JSON);
  191. req.setEntity(params);
  192. Response res = lowLevelClient.performRequest(req);
  193. // 解析
  194. String body = EntityUtils.toString(res.getEntity());
  195. JSONArray data = JSON.parseObject(body).getJSONObject("hits").getJSONArray("hits");
  196. // 遍历每个文档、将高的分的文档保存
  197. List<String> contents = new ArrayList<>();
  198. data.stream().map(n->JSONObject.parseObject(n.toString())).forEach(n->{
  199. // 得分高的才作为资料避免 llm 幻觉
  200. Double _score = n.getDouble("_score");
  201. if(_score >= scoreThreshold){
  202. // 文本
  203. String content = n.getJSONObject("_source").getString("content");
  204. contents.add(content);
  205. System.out.println("召回文档数据:"+n);
  206. }
  207. });
  208. System.out.println();
  209. return contents;
  210. }
  211. // 拼接 prompt
  212. public static String prompt(List<String> content,String q){
  213. StringBuilder question = new StringBuilder();
  214. question.append("你好,下面是我搜索得到的资料:\n");
  215. if(content.size()==0){
  216. question.append("无。\n");
  217. }
  218. for (int i = 0; i < content.size() ; i++) {
  219. question.append("("+(i+1)+")").append(content.get(i)).append("\n");
  220. }
  221. question.append("\n");
  222. question.append("请帮我根据上面的资料分析下面的问题,并帮我根据资料列出相关依据:\n");
  223. question.append(q).append("\n");
  224. question.append("\n");
  225. question.append("如果根据资料无法分析请回复不知道!");
  226. return question.toString();
  227. }
  228. // 调用 LLM 生成回答
  229. public static String llmAnswer(String question) throws Exception{
  230. JSONObject data = new JSONObject();
  231. data.put("prompt",question);
  232. data.put("history",null);
  233. HttpPost httpPost = new HttpPost(llm_host);
  234. httpPost.addHeader("Content-Type", "application/json;charset=utf-8");
  235. httpPost.setEntity(new StringEntity(data.toString(), StandardCharsets.UTF_8));
  236. CloseableHttpResponse response = HttpClients.createDefault().execute(httpPost);
  237. HttpEntity resEntity = response.getEntity();
  238. String resp = EntityUtils.toString(resEntity,"utf-8");
  239. return JSONObject.parseObject(resp).getString("response");
  240. }
  241. public static void main(String[] args) throws Exception{
  242. // 创建索引
  243. // createIndex();
  244. // 读取并拆分文档、提交 Elasticsearch
  245. // uploadDoc(parseDoc(500,100));
  246. // 执行混合检索
  247. String question = "结婚领证登记需要双发到场吗?";
  248. List<String> contents = search(question);
  249. // 执行混合检索并拼接 prompt
  250. String prompt = prompt(contents,question);
  251. // 调用 LLM 生成回答
  252. String answer = llmAnswer(prompt);
  253. System.out.println("-----------");
  254. System.out.println("Question:");
  255. System.out.println(question);
  256. System.out.println("-----------");
  257. System.out.println("Prompt:");
  258. System.out.println(prompt);
  259. System.out.println("-----------");
  260. System.out.println("Answer:");
  261. System.out.println(answer);
  262. }
  263. }


 
 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/972407
推荐阅读
相关标签
  

闽ICP备14008679号