赞
踩
第一章 Springboot RAG 一站式混合搜索方案
最近在做一个政策类查询的RAG方案,做成一站式可以快速使用的方案。
数据库是PG, PGVector作为向量数据库,采用Hybrid Search方法来同时匹配向量和其他字段。
项目采用Springboot 作为后端;大模型相关的,使用到的API有Chatgpt, moonshot, qwen,讯飞星火等不同厂家的方案。
该方案从产品方面来考虑,可扩展性,可便利性等没有太多考虑;从单个项目来说,算是一个可用的方案。
Spring AI 支持所有主要的模型提供商,如 OpenAI、Microsoft、Amazon、Google 和 Huggingface;国内的大模型还没有支持,国内大模型的API的返回,有几个是兼容OpenAI的,另外一些是不兼容的,需要做不少工作来完全兼容。这是后面可以优化的方向,做成一个统一的接口,便于系统维护和更多人的上手使用。
系统的数据库采用PG, 文本也是放在一个text字段中,用PG自带的全文检索,同时把向量匹配也放到一起过滤,所以向量数据库采用PGVector。
PGVector的安装有很多写的详细的过程,这里略过。
PostgreSQL 全文检索:PostgreSQL 自带的全文检索功能可以使用 tsvector 和 tsquery 数据类型,通过分词和倒排索引来实现语义搜索的基本需求。
Zhparser 分词插件:对于中文文本,可以使用 PostgreSQL 的 Zhparser 插件进行中文分词,结合全文检索功能实现语义搜索。
系统里面使用SpringBoot, MyBatisPlus的方式来,目前MyBatisPlus本身并不支持PGVector。需要增加一个PGVector的方式,
PGVector有项目https://github.com/pgvector/pgvector-java.git, 实现过程中参考该项目来实现PGVector;(也可以参考SpringAI项目对PGVector支持)
package com.md.gpt.vector; import org.postgresql.PGConnection; import org.postgresql.util.ByteConverter; import org.postgresql.util.PGBinaryObject; import org.postgresql.util.PGobject; import java.io.Serializable; import java.sql.Connection; import java.sql.SQLException; import java.util.Arrays; import java.util.List; import java.util.Objects; public class PGvector extends PGobject implements PGBinaryObject, Serializable, Cloneable { private float[] vec; /** * Constructor */ public PGvector() { type = "vector"; } /** * Constructor * * @param v float array */ public PGvector(float[] v) { this(); vec = v; } /** * Constructor * * @param <T> number * @param v list of numbers */ public <T extends Number> PGvector(List<T> v) { this(); if (Objects.isNull(v)) { vec = null; } else { vec = new float[v.size()]; int i = 0; for (T f : v) { vec[i++] = f.floatValue(); } } } /** * Constructor * * @param s text representation of a vector * @throws SQLException exception */ public PGvector(String s) throws SQLException { this(); setValue(s); } /** * Sets the value from a text representation of a vector */ @Override public void setValue(String s) throws SQLException { if (s == null) { vec = null; } else { String[] sp = s.substring(1, s.length() - 1).split(","); vec = new float[sp.length]; for (int i = 0; i < sp.length; i++) { vec[i] = Float.parseFloat(sp[i]); } } } /** * Returns the text representation of a vector */ @Override public String getValue() { if (vec == null) { return null; } else { return Arrays.toString(vec).replace(" ", ""); } } /** * Returns the number of bytes for the binary representation */ @Override public int lengthInBytes() { return vec == null ? 0 : 4 + vec.length * 4; } /** * Sets the value from a binary representation of a vector */ @Override public void setByteValue(byte[] value, int offset) throws SQLException { int dim = ByteConverter.int2(value, offset); int unused = ByteConverter.int2(value, offset + 2); if (unused != 0) { throw new SQLException("expected unused to be 0"); } vec = new float[dim]; for (int i = 0; i < dim; i++) { vec[i] = ByteConverter.float4(value, offset + 4 + i * 4); } } /** * Writes the binary representation of a vector */ @Override public void toBytes(byte[] bytes, int offset) { if (vec == null) { return; } // server will error on overflow due to unconsumed buffer // could set to Short.MAX_VALUE for friendlier error message ByteConverter.int2(bytes, offset, vec.length); ByteConverter.int2(bytes, offset + 2, 0); for (int i = 0; i < vec.length; i++) { ByteConverter.float4(bytes, offset + 4 + i * 4, vec[i]); } } /** * Returns an array * * @return an array */ public float[] toArray() { return vec; } /** * Registers the vector type * * @param conn connection * @throws SQLException exception */ public static void addVectorType(Connection conn) throws SQLException { conn.unwrap(PGConnection.class).addDataType("vector", PGvector.class); } }
PGvectorTypeHandler 类主要如下 import com.md.gpt.util.FloatUtil; import com.md.gpt.vector.PGvector; import lombok.extern.slf4j.Slf4j; import org.apache.ibatis.type.BaseTypeHandler; import org.apache.ibatis.type.JdbcType; import org.apache.ibatis.type.MappedJdbcTypes; import org.apache.ibatis.type.MappedTypes; import java.sql.*; import java.util.Arrays; import static org.springframework.util.ObjectUtils.toObjectArray; @Slf4j @MappedTypes({PGvector.class}) public class PGvectorTypeHandler extends BaseTypeHandler<PGvector> { @Override public void setNonNullParameter(PreparedStatement ps, int i, PGvector parameter, JdbcType jdbcType) throws SQLException { log.info("Getting PGvector result by column name: {}", parameter.toString()); Connection conn = ps.getConnection(); Float[] boxedArray = FloatUtil.toObjectArray(parameter.toArray()); Array sqlArray = conn.createArrayOf("float", boxedArray); ps.setArray(i, sqlArray); } @Override public PGvector getNullableResult(ResultSet rs, String columnName) throws SQLException { log.info("Getting PGvector result by column name: {}", columnName); Array array = rs.getArray(columnName); if (array != null) { Float[] javaArray = (Float[])array.getArray(); return new PGvector(FloatUtil.toPrimitiveArray(javaArray)); } return null; } @Override public PGvector getNullableResult(ResultSet rs, int columnIndex) throws SQLException { log.info("Getting PGvector result by column index: {}", columnIndex); return (PGvector) rs.getObject(columnIndex); } @Override public PGvector getNullableResult(CallableStatement cs, int columnIndex) throws SQLException { log.info("Getting PGvector result by column index: {}", columnIndex); return (PGvector) cs.getObject(columnIndex); } }
MyBatisConfig.java对应修改
@Bean public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception { log.info("Registering PGvectorTypeHandler in sqlSessionFactory"); MybatisSqlSessionFactoryBean sqlSessionFactoryBean = new MybatisSqlSessionFactoryBean(); sqlSessionFactoryBean.setDataSource(dataSource); // 注册 MyBatis Plus 拦截器 sqlSessionFactoryBean.setPlugins(mybatisPlusInterceptor()); // Set location of Mapper XML files sqlSessionFactoryBean.setMapperLocations( new PathMatchingResourcePatternResolver().getResources("classpath:/mapper/*.xml") ); // 创建 MybatisConfiguration 实例并配置 MybatisConfiguration configuration = new MybatisConfiguration(); configuration.addMappers("com.md.gpt.mapper"); // configuration.addMapperLocation("classpath:mapper/*.xml"); // 例如,注册自定义类型处理器 configuration.getTypeHandlerRegistry().register(PGvectorTypeHandler.class); sqlSessionFactoryBean.setConfiguration(configuration); return sqlSessionFactoryBean.getObject(); } @Bean public TypeHandlerRegistry typeHandlerRegistry() { log.info("Registering PGvectorTypeHandler in TypeHandlerRegistry"); TypeHandlerRegistry registry = new TypeHandlerRegistry(); // 注册自定义类型处理器 registry.register(PGvectorTypeHandler.class); return registry; }
PGvectorConfiguration.java
import com.md.gpt.mybatis.PGvectorTypeHandler; import com.md.gpt.vector.PGvector; import lombok.extern.slf4j.Slf4j; import org.springframework.boot.CommandLineRunner; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import javax.sql.DataSource; import java.sql.Connection; import java.sql.SQLException; @Configuration @Slf4j public class PGvectorConfiguration { @Bean CommandLineRunner registerPGvectorType(DataSource dataSource) { return args -> { try (Connection conn = dataSource.getConnection()) { log.info("Registering PGvector type with the database"); PGvector.addVectorType(conn); // This registers the 'vector' type } catch (SQLException e) { throw new RuntimeException("Failed to register PGvector type with the database", e); } }; } }
FaguikuMapper.java
List<Faguiku> findNewClosestEmbeddings(@Param("embedding") PGvector embedding, String yearInfo,
String title, String keywords, String region, String docNumber,
String startDate, String endDate, String status, String fileType,
String docDanwei, String fileTaxType,int pageSize, int offset);
FaguikuMapper.xml
<select id="findNewClosestEmbeddings" resultMap="BaseResultMap"> SELECT id, title, doc_number, date_written, status, link, attachment_link, file_type, doc_danwei, file_tax_type, status_sort_order FROM faguiku <where> <if test="yearInfo != null and yearInfo.trim() != ''"> title LIKE CONCAT('%', #{yearInfo}, '%') OR doc_number LIKE CONCAT('%', #{yearInfo}, '%') </if> <if test="title != null and title.trim() != ''"> AND to_tsvector('zh_cn', title) @@ plainto_tsquery('zh_cn', #{title}) </if> <if test="keywords != null and keywords.trim() != ''"> AND to_tsvector('zh_cn', content2) @@ plainto_tsquery('zh_cn', #{keywords}) </if> <if test="region != null and region.trim() != ''"> AND REPLACE(REPLACE(source, '省', ''), '市', '') = REPLACE(REPLACE(#{region}, '省', ''), '市', '') </if> <if test="docNumber != null and docNumber.trim() != ''"> AND doc_number LIKE CONCAT('%', #{docNumber}, '%') </if> <if test="status != null and status.trim() != ''"> AND status = #{status} </if> <if test="docDanwei != null and docDanwei.trim() != ''"> AND doc_danwei LIKE CONCAT('%', #{docDanwei}, '%') </if> <if test="fileTaxType != null and fileTaxType.trim() != ''"> AND file_tax_type = #{fileTaxType} </if> <if test="startDate != null and startDate.trim() != ''"> AND date_written >= #{startDate} </if> <if test="endDate != null and endDate.trim() != ''"> AND date_written <= #{endDate} </if> </where> ORDER BY embeddings::vector <![CDATA[ <=> ]]> #{embedding}::vector, status_sort_order, date_written desc LIMIT #{pageSize} OFFSET #{offset} </select>
embeddings 通过Embedding API去获取,可以通过各个大模型的Embedding API,也可以通过部署Embeddings API, 根据huggingface中的排名来选择中文支持较好的Embedding模型
这种方案,较好的考虑到了查询的方便性,对需要管理的文档,统一在数据库中管理;否则如果有成千上万篇文档,而不能有效的通过系统管理起来,那几乎很难维护了。通过Hybrid Search,结合传统数据库对于一些字段的完全匹配,结合全文搜索和向量搜索,得出来的结果,可以根据查询结果来调整,把符合条件的都过滤出来给用户做下一步分析使用。
有几个待进一步探讨探讨的问题:
1.文本分段,文章内容来源于很多地方,通过按照chunksize overlap来分段,然后再embedding, 有些自然段落被划分错了。通过调用大模型来划分,Token消耗比较多,时间比较久。目前使用的库里面,用到了3000万Token。而且有些划分还是不太合理,需要人工介入才能划分的合理。
2. Rerank模型,因为排序在sql里面做了,所以没有用rerank模型,准备按照rerank模型的算法看看能不能本地部署或者实现,来看看rerank能提高多少
3. 查询的文档比较多,最后给大模型来整理的话,一个是传入的长度有限制,不能把所有结果都传过去;二是太长的查询,消耗的也多。
4. 其他RAG方面的最新论文探索,比如Graph RAG等,后续进一步研究
5. 其他问题,欢迎加微信交流 rogerlzp
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。