当前位置:   article > 正文

Milvus 基本操作：milvusClient.search(searchParam) 与相似度评分

关键词：Milvus、milvusClient.search(searchParam)、评分

1、Maven 依赖

  <!-- Milvus Java SDK 2.3.3. -->
  <dependency>
      <groupId>io.milvus</groupId>
      <artifactId>milvus-sdk-java</artifactId>
      <version>2.3.3</version>
      <!-- Excludes the SDK's transitive logging artifacts, presumably so they do not clash with
           the application's own SLF4J / Log4j setup. -->
      <exclusions>
          <exclusion>
              <groupId>org.slf4j</groupId>
              <artifactId>slf4j-api</artifactId>
          </exclusion>
          <exclusion>
              <groupId>org.apache.logging.log4j</groupId>
              <artifactId>log4j-slf4j-impl</artifactId>
          </exclusion>
      </exclusions>
  </dependency>

2、MivusService：封装 Milvus 的基本操作

  1. @Service
  2. @Slf4j
  3. public class MivusService {
  4. @Autowired
  5. MilvusServiceClient milvusClient;
  6. private String clientId;
  7. /**
  8. * 同步搜索milvus
  9. * @param collectionName 表名
  10. * @param vectors 查询向量
  11. * @param topK 最相似的向量个数
  12. * @return
  13. */
  14. public List<Long> search(String collectionName, List<List<Float>> vectors, Integer topK) {
  15. Assert.notNull(collectionName, "collectionName is null");
  16. Assert.notNull(vectors, "vectors is null");
  17. Assert.notEmpty(vectors, "vectors is empty");
  18. Assert.notNull(topK, "topK is null");
  19. int nprobeVectorSize = vectors.get(0).size();
  20. String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
  21. SearchParam searchParam =
  22. SearchParam.newBuilder().withCollectionName(collectionName)
  23. .withParams(paramsInJson)
  24. .withMetricType(MetricType.L2)
  25. .withVectors(vectors)
  26. .withVectorFieldName("embeddings")
  27. .withTopK(topK)
  28. .build();
  29. R<SearchResults> searchResultsR = milvusClient.search(searchParam);
  30. SearchResults searchResultsRData = searchResultsR.getData();
  31. List<Long> topksList = searchResultsRData.getResults().getIds().getIntId().getDataList();
  32. return topksList;
  33. }
  34. /**
  35. * 同步搜索milvus
  36. * @param collectionName 表名
  37. * @param vectors 查询向量
  38. * @param topK 最相似的向量个数
  39. * @return
  40. */
  41. public List<Long> search1(String collectionName, List<List<Float>> vectors, Integer topK) {
  42. Assert.notNull(collectionName, "collectionName is null");
  43. Assert.notNull(vectors, "vectors is null");
  44. Assert.notEmpty(vectors, "vectors is empty");
  45. Assert.notNull(topK, "topK is null");
  46. int nprobeVectorSize = vectors.get(0).size();
  47. String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
  48. SearchParam searchParam =
  49. SearchParam.newBuilder().withCollectionName(collectionName)
  50. .withParams(paramsInJson)
  51. .withMetricType(MetricType.IP)
  52. .withVectors(vectors)
  53. .withVectorFieldName("embedding")
  54. .withTopK(topK)
  55. .build();
  56. R<SearchResults> searchResultsR = milvusClient.search(searchParam);
  57. SearchResults searchResultsRData = searchResultsR.getData();
  58. List<Long> topksList = searchResultsRData.getResults().getIds().getIntId().getDataList();
  59. return topksList;
  60. }
  61. /**
  62. * 同步搜索milvus,增加过滤条件搜索
  63. *
  64. * @param collectionName 表名
  65. * @param vectors 查询向量
  66. * @param topK 最相似的向量个数
  67. * @param exp 过滤条件:status=1
  68. * @return
  69. */
  70. public List<Long> search2(String collectionName, List<List<Float>> vectors, Integer topK, String exp) {
  71. Assert.notNull(collectionName, "collectionName is null");
  72. Assert.notNull(vectors, "vectors is null");
  73. Assert.notEmpty(vectors, "vectors is empty");
  74. Assert.notNull(topK, "topK is null");
  75. Assert.notNull(exp, "exp is null");
  76. int nprobeVectorSize = vectors.get(0).size();
  77. String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
  78. SearchParam searchParam =
  79. SearchParam.newBuilder().withCollectionName(collectionName)
  80. .withParams(paramsInJson)
  81. .withMetricType(MetricType.IP)
  82. .withVectors(vectors)
  83. .withExpr(exp)
  84. .withVectorFieldName("embedding")
  85. .withTopK(topK)
  86. .build();
  87. R<SearchResults> searchResultsR = milvusClient.search(searchParam);
  88. SearchResults searchResultsRData = searchResultsR.getData();
  89. List<Long> topksList = searchResultsRData.getResults().getIds().getIntId().getDataList();
  90. return topksList;
  91. }
  92. /**
  93. * 异步搜索milvus
  94. *
  95. * @param collectionName 表名
  96. * @param vectors 查询向量
  97. * @param partitionList 最相似的向量个数
  98. * @param topK
  99. * @return
  100. */
  101. public List<Long> searchAsync(String collectionName, List<List<Float>> vectors,
  102. List<String> partitionList, Integer topK) throws ExecutionException, InterruptedException {
  103. Assert.notNull(collectionName, "collectionName is null");
  104. Assert.notNull(vectors, "vectors is null");
  105. Assert.notEmpty(vectors, "vectors is empty");
  106. Assert.notNull(partitionList, "partitionList is null");
  107. Assert.notEmpty(partitionList, "partitionList is empty");
  108. Assert.notNull(topK, "topK is null");
  109. int nprobeVectorSize = vectors.get(0).size();
  110. String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
  111. SearchParam searchParam =
  112. SearchParam.newBuilder().withCollectionName(collectionName)
  113. .withParams(paramsInJson)
  114. .withVectors(vectors)
  115. .withTopK(topK)
  116. .withPartitionNames(partitionList)
  117. .build();
  118. ListenableFuture<R<SearchResults>> listenableFuture = milvusClient.searchAsync(searchParam);
  119. List<Long> resultIdsList = listenableFuture.get().getData().getResults().getTopksList();
  120. return resultIdsList;
  121. }
  122. /**
  123. * 获取分区集合
  124. * @param collectionName 表名
  125. * @return
  126. */
  127. public List<String> getPartitionsList(String collectionName) {
  128. Assert.notNull(collectionName, "collectionName is null");
  129. ShowPartitionsParam searchParam = ShowPartitionsParam.newBuilder().withCollectionName(collectionName).build();
  130. List<ByteString> byteStrings = milvusClient.showPartitions(searchParam).getData().getPartitionNamesList().asByteStringList();
  131. List<String> partitionList = Lists.newLinkedList();
  132. byteStrings.forEach(s -> {
  133. partitionList.add(s.toStringUtf8());
  134. });
  135. return partitionList;
  136. }
  137. public void loadCollection(String collectionName) {
  138. LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
  139. .withCollectionName(collectionName)
  140. .build();
  141. R<RpcStatus> response = milvusClient.loadCollection(loadCollectionParam);
  142. log.info("loadCollection {} is {}", collectionName, response.getData().getMsg());
  143. }
  144. public void releaseCollection(String collectionName) {
  145. ReleaseCollectionParam param = ReleaseCollectionParam.newBuilder()
  146. .withCollectionName(collectionName)
  147. .build();
  148. R<RpcStatus> response = milvusClient.releaseCollection(param);
  149. log.info("releaseCollection {} is {}", collectionName, response.getData().getMsg());
  150. }
  151. public void loadPartitions(String collectionName, List<String> partitionsName) {
  152. LoadPartitionsParam build = LoadPartitionsParam.newBuilder()
  153. .withCollectionName(collectionName)
  154. .withPartitionNames(partitionsName)
  155. .build();
  156. R<RpcStatus> rpcStatusR = milvusClient.loadPartitions(build);
  157. log.info("loadPartitions {} is {}", partitionsName, rpcStatusR.getData().getMsg());
  158. }
  159. public void releasePartitions(String collectionName, List<String> partitionsName) {
  160. ReleasePartitionsParam build = ReleasePartitionsParam.newBuilder()
  161. .withCollectionName(collectionName)
  162. .withPartitionNames(partitionsName)
  163. .build();
  164. R<RpcStatus> rpcStatusR = milvusClient.releasePartitions(build);
  165. log.info("releasePartition {} is {}", collectionName, rpcStatusR.getData().getMsg());
  166. }
  167. public boolean isExitCollection(String collectionName) {
  168. HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
  169. .withCollectionName(collectionName)
  170. .build();
  171. R<Boolean> response = milvusClient.hasCollection(hasCollectionParam);
  172. Boolean isExists = response.getData();
  173. log.info("collection {} is exists: {}", collectionName, isExists);
  174. return isExists;
  175. }
  176. public Boolean creatCollection(String collectionName) {
  177. // 主键字段
  178. FieldType fieldType1 = FieldType.newBuilder()
  179. .withName(Content.Field.ID)
  180. .withDescription("primary key")
  181. .withDataType(DataType.Int64)
  182. .withPrimaryKey(true)
  183. .withAutoID(true)
  184. .build();
  185. // 文本字段
  186. FieldType fieldType2 = FieldType.newBuilder()
  187. .withName(Content.Field.CONTENT)
  188. .withDataType(DataType.VarChar)
  189. .withMaxLength(Content.MAX_LENGTH)
  190. .build();
  191. // 向量字段
  192. FieldType fieldType3 = FieldType.newBuilder()
  193. .withName(Content.Field.CONTENT_VECTOR)
  194. .withDataType(DataType.FloatVector)
  195. .withDimension(Content.FEATURE_DIM)
  196. .build();
  197. // 创建collection
  198. CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
  199. .withCollectionName(collectionName)
  200. .withDescription("Schema of Content")
  201. .withShardsNum(Content.SHARDS_NUM)
  202. .addFieldType(fieldType1)
  203. .addFieldType(fieldType2)
  204. .addFieldType(fieldType3)
  205. .build();
  206. R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
  207. log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
  208. return response.getData().getMsg().equals("Success");
  209. }
  210. public Boolean dropCollection(String collectionName) {
  211. DropCollectionParam book = DropCollectionParam.newBuilder()
  212. .withCollectionName(collectionName)
  213. .build();
  214. R<RpcStatus> response = milvusClient.dropCollection(book);
  215. return response.getData().getMsg().equals("Success");
  216. }
  217. public void createPartition(String collectionName, String partitionName) {
  218. CreatePartitionParam param = CreatePartitionParam.newBuilder()
  219. .withCollectionName(collectionName)
  220. .withPartitionName(partitionName)
  221. .build();
  222. R<RpcStatus> partition = milvusClient.createPartition(param);
  223. String msg = partition.getData().getMsg();
  224. log.info("create partition: {} in collection: {} is: {}", partition, collectionName, msg);
  225. }
  226. public Boolean createIndex(String collectionName) {
  227. // IndexType
  228. final IndexType INDEX_TYPE = IndexType.IVF_FLAT;
  229. // ExtraParam 建议值为 4 × sqrt(n), 其中 n 指 segment 最多包含的 entity 条数。
  230. final String INDEX_PARAM = "{\"nlist\":16384}";
  231. long startIndexTime = System.currentTimeMillis();
  232. R<RpcStatus> response = milvusClient.createIndex(CreateIndexParam.newBuilder()
  233. .withCollectionName(collectionName)
  234. .withIndexName(Content.CONTENT_INDEX)
  235. .withFieldName(Content.Field.CONTENT_VECTOR)
  236. .withMetricType(MetricType.L2)
  237. .withIndexType(INDEX_TYPE)
  238. .withExtraParam(INDEX_PARAM)
  239. .withSyncMode(Boolean.TRUE)
  240. .withSyncWaitingInterval(500L)
  241. .withSyncWaitingTimeout(30L)
  242. .build());
  243. long endIndexTime = System.currentTimeMillis();
  244. log.info("Succeed in " + (endIndexTime - startIndexTime) / 1000.00 + " seconds!");
  245. log.info("createIndex --->>> {} ", response.toString());
  246. GetIndexBuildProgressParam build = GetIndexBuildProgressParam.newBuilder()
  247. .withCollectionName(collectionName)
  248. .build();
  249. R<GetIndexBuildProgressResponse> idnexResp = milvusClient.getIndexBuildProgress(build);
  250. log.info("getIndexBuildProgress --->>> {}", idnexResp.getStatus());
  251. return response.getData().getMsg().equals("Success");
  252. }
  253. public ReplyMsg insert(String collectionName, List<InsertParam.Field> fields) {
  254. InsertParam insertParam = InsertParam.newBuilder()
  255. .withCollectionName(collectionName)
  256. .withFields(fields)
  257. .build();
  258. R<MutationResult> mutationResultR = milvusClient.insert(insertParam);
  259. log.info("Flushing...");
  260. long startFlushTime = System.currentTimeMillis();
  261. milvusClient.flush(FlushParam.newBuilder()
  262. .withCollectionNames(Collections.singletonList(collectionName))
  263. .withSyncFlush(true)
  264. .withSyncFlushWaitingInterval(50L)
  265. .withSyncFlushWaitingTimeout(30L)
  266. .build());
  267. long endFlushTime = System.currentTimeMillis();
  268. log.info("Succeed in " + (endFlushTime - startFlushTime) / 1000.00 + " seconds!");
  269. if (mutationResultR.getStatus() == 0){
  270. long insertCnt = mutationResultR.getData().getInsertCnt();
  271. log.info("Successfully! Total number of entities inserted: {} ", insertCnt);
  272. return ReplyMsg.ofSuccess("success", insertCnt);
  273. }
  274. log.error("InsertRequest failed!");
  275. return ReplyMsg.ofErrorMsg("InsertRequest failed!");
  276. }
  277. public List<List<SearchResultVo>> searchTopKSimilarity(SearchParamVo searchParamVo) {
  278. log.info("------search TopK Similarity------");
  279. SearchParam searchParam = SearchParam.newBuilder()
  280. .withCollectionName(searchParamVo.getCollectionName())
  281. .withMetricType(MetricType.L2)
  282. .withOutFields(searchParamVo.getOutputFields())
  283. .withTopK(searchParamVo.getTopK())
  284. .withVectors(searchParamVo.getQueryVectors())
  285. .withVectorFieldName(Content.Field.CONTENT_VECTOR)
  286. .withParams(searchParamVo.getParams())
  287. .build();
  288. R<SearchResults> respSearch = milvusClient.search(searchParam);
  289. if (respSearch.getData() == null) {
  290. return null;
  291. }
  292. log.info("------ process query results ------");
  293. SearchResultsWrapper wrapper = new SearchResultsWrapper(respSearch.getData().getResults());
  294. List<List<SearchResultVo>> result = new ArrayList<>();
  295. for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
  296. List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
  297. List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
  298. List<SearchResultVo> list = new ArrayList<>();
  299. for (int j = 0; j < scores.size(); ++j) {
  300. SearchResultsWrapper.IDScore score = scores.get(j);
  301. QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(j);
  302. long longID = score.getLongID();
  303. float distance = score.getScore();
  304. String content = (String) rowRecord.get(searchParamVo.getOutputFields().get(0));
  305. log.info("Top " + j + " ID:" + longID + " Distance:" + distance);
  306. log.info("Content: " + content);
  307. list.add(SearchResultVo.builder().id(longID).score(distance).conent(content).build());
  308. }
  309. result.add(list);
  310. }
  311. log.info("Successfully!");
  312. return result;
  313. }
  314. public Boolean creatCollectionERP(String collectionName) {
  315. // 主键字段
  316. FieldType fieldType1 = FieldType.newBuilder()
  317. .withName(Content.Field.ID)
  318. .withDescription("primary key")
  319. .withDataType(DataType.Int64)
  320. .withPrimaryKey(true)
  321. .withAutoID(true)
  322. .build();
  323. // 文本字段
  324. FieldType fieldType2 = FieldType.newBuilder()
  325. .withName(Content.Field.CONTENT)
  326. .withDataType(DataType.VarChar)
  327. .withMaxLength(Content.MAX_LENGTH)
  328. .build();
  329. // 向量字段
  330. FieldType fieldType3 = FieldType.newBuilder()
  331. .withName(Content.Field.CONTENT_VECTOR)
  332. .withDataType(DataType.FloatVector)
  333. .withDimension(Content.FEATURE_DIM)
  334. .build();
  335. FieldType fieldType4 = FieldType.newBuilder()
  336. .withName(Content.Field.CONTENT_ANSWER)
  337. .withDataType(DataType.VarChar)
  338. .withMaxLength(Content.MAX_LENGTH)
  339. .build();
  340. FieldType fieldType5 = FieldType.newBuilder()
  341. .withName(Content.Field.TITLE)
  342. .withDataType(DataType.VarChar)
  343. .withMaxLength(Content.MAX_LENGTH)
  344. .build();
  345. FieldType fieldType6 = FieldType.newBuilder()
  346. .withName(Content.Field.PARAM)
  347. .withDataType(DataType.VarChar)
  348. .withMaxLength(Content.MAX_LENGTH)
  349. .build();
  350. FieldType fieldType7 = FieldType.newBuilder()
  351. .withName(Content.Field.TYPE)
  352. .withDataType(DataType.VarChar)
  353. .withMaxLength(Content.MAX_LENGTH)
  354. .build();
  355. // 创建collection
  356. CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
  357. .withCollectionName(collectionName)
  358. .withDescription("Schema of Content ERP")
  359. .withShardsNum(Content.SHARDS_NUM)
  360. .addFieldType(fieldType1)
  361. .addFieldType(fieldType2)
  362. .addFieldType(fieldType3)
  363. .addFieldType(fieldType4)
  364. .addFieldType(fieldType5)
  365. .addFieldType(fieldType6)
  366. .addFieldType(fieldType7)
  367. .build();
  368. R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
  369. log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
  370. return response.getData().getMsg().equals("Success");
  371. }
  372. public Boolean creatCollectionERPCLIP(String collectionName) {
  373. // 主键字段
  374. FieldType fieldType1 = FieldType.newBuilder()
  375. .withName(Content.Field.ID)
  376. .withDescription("primary key")
  377. .withDataType(DataType.Int64)
  378. .withPrimaryKey(true)
  379. .withAutoID(true)
  380. .build();
  381. // 文本字段
  382. FieldType fieldType2 = FieldType.newBuilder()
  383. .withName(Content.Field.CONTENT)
  384. .withDataType(DataType.VarChar)
  385. .withMaxLength(Content.MAX_LENGTH)
  386. .build();
  387. // 向量字段
  388. FieldType fieldType3 = FieldType.newBuilder()
  389. .withName(Content.Field.CONTENT_VECTOR)
  390. .withDataType(DataType.FloatVector)
  391. .withDimension(Content.FEATURE_DIM_CLIP)
  392. .build();
  393. FieldType fieldType4 = FieldType.newBuilder()
  394. .withName(Content.Field.CONTENT_ANSWER)
  395. .withDataType(DataType.VarChar)
  396. .withMaxLength(Content.MAX_LENGTH)
  397. .build();
  398. FieldType fieldType5 = FieldType.newBuilder()
  399. .withName(Content.Field.TITLE)
  400. .withDataType(DataType.VarChar)
  401. .withMaxLength(Content.MAX_LENGTH)
  402. .build();
  403. FieldType fieldType6 = FieldType.newBuilder()
  404. .withName(Content.Field.PARAM)
  405. .withDataType(DataType.VarChar)
  406. .withMaxLength(Content.MAX_LENGTH)
  407. .build();
  408. FieldType fieldType7 = FieldType.newBuilder()
  409. .withName(Content.Field.TYPE)
  410. .withDataType(DataType.VarChar)
  411. .withMaxLength(Content.MAX_LENGTH)
  412. .build();
  413. FieldType fieldType8 = FieldType.newBuilder()
  414. .withName(Content.Field.LABEL)
  415. .withDataType(DataType.VarChar)
  416. .withMaxLength(Content.MAX_LENGTH)
  417. .build();
  418. // 创建collection
  419. CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
  420. .withCollectionName(collectionName)
  421. .withDescription("Schema of Content ERP")
  422. .withShardsNum(Content.SHARDS_NUM)
  423. .addFieldType(fieldType1)
  424. .addFieldType(fieldType2)
  425. .addFieldType(fieldType3)
  426. .addFieldType(fieldType4)
  427. .addFieldType(fieldType5)
  428. .addFieldType(fieldType6)
  429. .addFieldType(fieldType7)
  430. .addFieldType(fieldType8)
  431. .build();
  432. R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
  433. log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
  434. return response.getData().getMsg().equals("Success");
  435. }
  436. public Boolean creatCollectionERPNLP(String collectionName) {
  437. // 主键字段
  438. FieldType fieldType1 = FieldType.newBuilder()
  439. .withName(Content.Field.ID)
  440. .withDescription("primary key")
  441. .withDataType(DataType.Int64)
  442. .withPrimaryKey(true)
  443. .withAutoID(true)
  444. .build();
  445. // 文本字段
  446. FieldType fieldType2 = FieldType.newBuilder()
  447. .withName(Content.Field.CONTENT)
  448. .withDataType(DataType.VarChar)
  449. .withMaxLength(Content.MAX_LENGTH)
  450. .build();
  451. // 向量字段
  452. FieldType fieldType3 = FieldType.newBuilder()
  453. .withName(Content.Field.CONTENT_VECTOR)
  454. .withDataType(DataType.FloatVector)
  455. .withDimension(Content.FEATURE_DIM_CLIP)
  456. .build();
  457. FieldType fieldType4 = FieldType.newBuilder()
  458. .withName(Content.Field.CONTENT_ANSWER)
  459. .withDataType(DataType.VarChar)
  460. .withMaxLength(Content.MAX_LENGTH)
  461. .build();
  462. FieldType fieldType5 = FieldType.newBuilder()
  463. .withName(Content.Field.TITLE)
  464. .withDataType(DataType.VarChar)
  465. .withMaxLength(Content.MAX_LENGTH)
  466. .build();
  467. FieldType fieldType6 = FieldType.newBuilder()
  468. .withName(Content.Field.PARAM)
  469. .withDataType(DataType.VarChar)
  470. .withMaxLength(Content.MAX_LENGTH)
  471. .build();
  472. FieldType fieldType7 = FieldType.newBuilder()
  473. .withName(Content.Field.TYPE)
  474. .withDataType(DataType.VarChar)
  475. .withMaxLength(Content.MAX_LENGTH)
  476. .build();
  477. FieldType fieldType8 = FieldType.newBuilder()
  478. .withName(Content.Field.LABEL)
  479. .withDataType(DataType.VarChar)
  480. .withMaxLength(Content.MAX_LENGTH)
  481. .build();
  482. // 创建collection
  483. CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
  484. .withCollectionName(collectionName)
  485. .withDescription("Schema of Content ERP")
  486. .withShardsNum(Content.SHARDS_NUM)
  487. .addFieldType(fieldType1)
  488. .addFieldType(fieldType2)
  489. .addFieldType(fieldType3)
  490. .addFieldType(fieldType4)
  491. .addFieldType(fieldType5)
  492. .addFieldType(fieldType6)
  493. .addFieldType(fieldType7)
  494. .addFieldType(fieldType8)
  495. .build();
  496. R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
  497. log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
  498. return response.getData().getMsg().equals("Success");
  499. }
  500. public List<List<SearchERPResultVo>> searchERPTopKSimilarity(SearchERPParamVo searchParamVo) {
  501. log.info("------search ERP TopK Similarity------");
  502. SearchParam searchParam = SearchParam.newBuilder()
  503. .withCollectionName(searchParamVo.getCollectionName())
  504. .withMetricType(MetricType.L2)
  505. .withOutFields(searchParamVo.getOutputFields())
  506. .withTopK(searchParamVo.getTopK())
  507. .withVectors(searchParamVo.getQueryVectors())
  508. .withVectorFieldName(Content.Field.CONTENT_VECTOR)
  509. .withParams(searchParamVo.getParams())
  510. .build();
  511. R<SearchResults> respSearch = milvusClient.search(searchParam);
  512. if (respSearch.getData() == null) {
  513. return null;
  514. }
  515. log.info("------ process query results ------");
  516. SearchResultsWrapper wrapper = new SearchResultsWrapper(respSearch.getData().getResults());
  517. List<List<SearchERPResultVo>> result = new ArrayList<>();
  518. for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
  519. List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
  520. List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
  521. List<SearchERPResultVo> list = new ArrayList<>();
  522. for (int j = 0; j < scores.size(); ++j) {
  523. SearchResultsWrapper.IDScore score = scores.get(j);
  524. QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(j);
  525. long longID = score.getLongID();
  526. float distance = score.getScore();
  527. String content = (String) rowRecord.get(searchParamVo.getOutputFields().get(0));
  528. String contentAnswer = (String) rowRecord.get(searchParamVo.getOutputFields().get(1));
  529. String title = (String) rowRecord.get(searchParamVo.getOutputFields().get(2));
  530. log.info("Top " + j + " ID:" + longID + " Distance:" + distance);
  531. log.info("Content: " + content);
  532. list.add(SearchERPResultVo.builder().id(longID).score(distance).content(content).contentAnswer(contentAnswer).title(title).build());
  533. }
  534. result.add(list);
  535. }
  536. log.info("Successfully!");
  537. return result;
  538. }
  539. public List<List<SearchNLPResultVo>> searchNLPTopKSimilarity(SearchNLPParamVo searchParamVo) {
  540. log.info("------search ERP TopK Similarity------");
  541. SearchParam searchParam = SearchParam.newBuilder()
  542. .withCollectionName(searchParamVo.getCollectionName())
  543. .withMetricType(MetricType.L2)
  544. .withOutFields(searchParamVo.getOutputFields())
  545. .withTopK(searchParamVo.getTopK())
  546. .withVectors(searchParamVo.getQueryVectors())
  547. .withVectorFieldName(Content.Field.CONTENT_VECTOR)
  548. .withParams(searchParamVo.getParams())
  549. .withExpr(searchParamVo.getExpr())
  550. .build();
  551. R<SearchResults> respSearch = milvusClient.search(searchParam);
  552. if (respSearch.getData() == null) {
  553. return null;
  554. }
  555. log.info("------ process query results ------");
  556. SearchResultsWrapper wrapper = new SearchResultsWrapper(respSearch.getData().getResults());
  557. List<List<SearchNLPResultVo>> result = new ArrayList<>();
  558. for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
  559. List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
  560. List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
  561. List<SearchNLPResultVo> list = new ArrayList<>();
  562. for (int j = 0; j < scores.size(); ++j) {
  563. SearchResultsWrapper.IDScore score = scores.get(j);
  564. QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(j);
  565. long longID = score.getLongID();
  566. float distance = score.getScore();
  567. String content = (String) rowRecord.get(searchParamVo.getOutputFields().get(0));
  568. String contentAnswer = (String) rowRecord.get(searchParamVo.getOutputFields().get(1));
  569. String title = (String) rowRecord.get(searchParamVo.getOutputFields().get(2));
  570. log.info("Top " + j + " ID:" + longID + " Distance:" + distance);
  571. log.info("Content: " + content);
  572. list.add(SearchNLPResultVo.builder().id(longID).score(distance).content(content).contentAnswer(contentAnswer).title(title).build());
  573. }
  574. result.add(list);
  575. }
  576. log.info("Successfully!");
  577. return result;
  578. }
  579. }

3、测试用例 

MilvusServiceERPNLPTest

  1. @SpringBootTest(classes = {DataChatgptApplication.class}, webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
  2. public class MilvusServiceERPNLPTest {
  3. @Autowired
  4. MivusService milvusService;
  5. @Autowired
  6. MilvusClient milvusClient;
  7. @Test
  8. void isExitCollection() {
  9. boolean mediumArticles = milvusService.isExitCollection(Content.COLLECTION_NAME_NLP);
  10. Assertions.assertTrue(mediumArticles);
  11. }
  12. @Test
  13. void creatCollection() {
  14. Boolean created = milvusService.creatCollectionERPNLP(Content.COLLECTION_NAME_NLP);
  15. Assertions.assertTrue(created);
  16. }
  17. @Test
  18. void createIndex(){
  19. Boolean index = milvusService.createIndex(Content.COLLECTION_NAME_NLP);
  20. Assertions.assertTrue(index);
  21. }
  22. @Test
  23. public void insertVector(){
  24. List<String> sentenceList = new ArrayList<>();
  25. sentenceList.add("网址是多少");
  26. List<String> contentAnswerList = new ArrayList<>();
  27. contentAnswerList.add("/home.ashx");
  28. List<String> titleList = new ArrayList<>();
  29. titleList.add("网址");
  30. List<String> paramList = new ArrayList<>();
  31. paramList.add("");
  32. List<String> typeList = new ArrayList<>();
  33. typeList.add("0");
  34. List<String> labelList = new ArrayList<>();
  35. labelList.add("操作直达");
  36. PaddleNewTextVo paddleNewTextVo = null;
  37. try {
  38. paddleNewTextVo = getVectorsLists(sentenceList);
  39. if (paddleNewTextVo == null) {
  40. // 获取不到再重试下
  41. paddleNewTextVo = getVectorsLists(sentenceList);
  42. }
  43. List<List<Double>> vectors = paddleNewTextVo.getVector();
  44. List<List<Float>> floatVectors = new ArrayList<>();
  45. for (List<Double> innerList : vectors) {
  46. List<Float> floatInnerList = new ArrayList<>();
  47. for (Double value : innerList) {
  48. floatInnerList.add(value.floatValue());
  49. }
  50. floatVectors.add(floatInnerList);
  51. }
  52. // 2.准备插入向量数据库
  53. List<InsertParam.Field> fields = new ArrayList<>();
  54. fields.add(new InsertParam.Field(Content.Field.CONTENT, sentenceList));
  55. fields.add(new InsertParam.Field(Content.Field.CONTENT_VECTOR, floatVectors));
  56. fields.add(new InsertParam.Field(Content.Field.CONTENT_ANSWER, contentAnswerList));
  57. fields.add(new InsertParam.Field(Content.Field.TITLE, titleList));
  58. fields.add(new InsertParam.Field(Content.Field.PARAM, paramList));
  59. fields.add(new InsertParam.Field(Content.Field.TYPE, typeList));
  60. fields.add(new InsertParam.Field(Content.Field.LABEL, labelList));
  61. // 3.执行操作
  62. milvusService.insert(Content.COLLECTION_NAME_NLP, fields);
  63. } catch (ApiException e) {
  64. System.out.println(e.getMessage());
  65. } catch (IOException e) {
  66. throw new RuntimeException(e);
  67. }
  68. }
  69. private static PaddleNewTextVo getVectorsLists(List<String> sentenceList) throws IOException {
  70. String url = "http://192.168.1.243:6001/"; //paddle
  71. URL obj = new URL(url);
  72. HttpURLConnection con = (HttpURLConnection) obj.openConnection();
  73. // 设置超时时间
  74. con.setConnectTimeout(50000);
  75. con.setReadTimeout(200000);
  76. con.setRequestMethod("POST");
  77. con.setRequestProperty("Content-Type", "application/json");
  78. con.setDoOutput(true);
  79. ObjectMapper objectParmMapper = new ObjectMapper();
  80. // 创建一个Map结构表示您的数据
  81. Map<String, List<Map<String, String>>> dataMap = new HashMap<>();
  82. dataMap.put("data", sentenceList.stream()
  83. .map(sentence -> Collections.singletonMap("text", sentence))
  84. .collect(Collectors.toList()));
  85. String jsonData = null;
  86. try {
  87. // 将Map转换为JSON字符串
  88. jsonData = objectParmMapper.writeValueAsString(dataMap);
  89. } catch (JsonProcessingException e) {
  90. System.err.println("Error converting to JSON: " + e.getMessage());
  91. }
  92. String data = jsonData;
  93. try(OutputStream os = con.getOutputStream()) {
  94. byte[] input = data.getBytes("utf-8");
  95. os.write(input, 0, input.length);
  96. }
  97. int responseCode = con.getResponseCode();
  98. System.out.println("Response Code: " + responseCode);
  99. PaddleNewTextVo paddleNewTextVo = null;
  100. if (responseCode == HttpURLConnection.HTTP_OK) { // 200表示成功
  101. BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()));
  102. String inputLine;
  103. StringBuilder content = new StringBuilder();
  104. while ((inputLine = in.readLine()) != null) {
  105. content.append(inputLine);
  106. }
  107. in.close();
  108. try {
  109. String contentStr = content.toString();
  110. // 直接解析JSON字符串到PaddleTextVo实例
  111. paddleNewTextVo = JSON.parseObject(contentStr, PaddleNewTextVo.class);
  112. } catch (Exception e) {
  113. System.err.println("Error parsing JSON: " + e.getMessage());
  114. }
  115. } else {
  116. System.out.println("Error Response Code: " + responseCode);
  117. BufferedReader errorReader = new BufferedReader(new InputStreamReader(con.getErrorStream()));
  118. String errorMessage;
  119. while ((errorMessage = errorReader.readLine()) != null) {
  120. System.out.println("Error Message: " + errorMessage);
  121. }
  122. errorReader.close();
  123. }
  124. return paddleNewTextVo;
  125. }
  126. @Test
  127. void searchTest(){
  128. // 0.加载向量集合
  129. milvusService.loadCollection(Content.COLLECTION_NAME_NLP);
  130. try {
  131. List<String> sentenceList = new ArrayList<>();
  132. sentenceList.add("XX列表");
  133. String label = "操作直达";
  134. // 1.获得向量
  135. // List<List<Float>> vectors = getVectorsLists(sentenceList);
  136. List<List<Float>> vectors = new ArrayList<>();
  137. SearchNLPParamVo searchParamVo = SearchNLPParamVo.builder()
  138. .collectionName(Content.COLLECTION_NAME_NLP)
  139. .queryVectors(vectors)
  140. .expr("label == '" + label + "'")
  141. .topK(3)
  142. .build();
  143. // 2.在向量数据库中进行搜索内容知识
  144. List<List<SearchNLPResultVo>> lists = milvusService.searchNLPTopKSimilarity(searchParamVo);
  145. lists.forEach(searchResultVos -> {
  146. searchResultVos.forEach(searchResultVo -> {
  147. System.out.println(searchResultVo.getContent());
  148. System.out.println(searchResultVo.getContentAnswer());
  149. System.out.println(searchResultVo.getTitle());
  150. System.out.println(searchResultVo.getLabel());
  151. });
  152. });
  153. } catch (ApiException e) {
  154. System.out.println(e.getMessage());
  155. } /*catch (IOException e) {
  156. throw new RuntimeException(e);
  157. }
  158. */
  159. }
  160. @Test
  161. public void insertTextVector() throws IOException {
  162. List<String> titleList = new ArrayList<>();
  163. List<String> sentenceList = new ArrayList<>();
  164. List<String> contentAnswerList = new ArrayList<>();
  165. List<String> paramList = new ArrayList<>();
  166. List<String> typeList = new ArrayList<>();
  167. String filePath = "src/main/resources/data/text.txt";
  168. try (BufferedReader reader = new BufferedReader(
  169. new InputStreamReader(new FileInputStream(filePath), StandardCharsets.UTF_8))) {
  170. // 使用4个竖线(||||)作为分隔符
  171. String line;
  172. while ((line = reader.readLine()) != null) {
  173. String[] parts = line.split("\\|\\|\\|\\|");
  174. if (parts.length >= 3) {
  175. titleList.add(parts[0].trim());
  176. sentenceList.add(parts[1].trim());
  177. contentAnswerList.add(parts[2].trim());
  178. paramList.add("");
  179. typeList.add("2");
  180. } else {
  181. System.out.println("Warning: Invalid format on line: " + line);
  182. }
  183. }
  184. // 打印或处理列表内容
  185. System.out.println("Title List: " + titleList);
  186. System.out.println("Sentence List: " + sentenceList);
  187. System.out.println("Content Answer List: " + contentAnswerList);
  188. } catch (IOException e) {
  189. System.err.println("Error reading file: " + e.getMessage());
  190. }
  191. try {
  192. // 1.获得向量
  193. TextEmbeddingParam param = TextEmbeddingParam
  194. .builder()
  195. .model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
  196. .texts(sentenceList).build();
  197. TextEmbedding textEmbedding = new TextEmbedding();
  198. TextEmbeddingResult result = textEmbedding.call(param);
  199. List<List<Float>> vectors = new ArrayList<>();
  200. for (int i = 0; i < result.getOutput().getEmbeddings().size(); i++) {
  201. List<Double> vector = result.getOutput().getEmbeddings().get(i).getEmbedding();
  202. List<Float> floatVector = vector.stream()
  203. .map(Double::floatValue)
  204. .collect(Collectors.toList());
  205. vectors.add(floatVector);
  206. }
  207. // 2.准备插入向量数据库
  208. List<InsertParam.Field> fields = new ArrayList<>();
  209. fields.add(new InsertParam.Field(Content.Field.CONTENT, sentenceList));
  210. fields.add(new InsertParam.Field(Content.Field.CONTENT_VECTOR, vectors));
  211. fields.add(new InsertParam.Field(Content.Field.CONTENT_ANSWER, contentAnswerList));
  212. fields.add(new InsertParam.Field(Content.Field.TITLE, titleList));
  213. fields.add(new InsertParam.Field(Content.Field.PARAM, paramList));
  214. fields.add(new InsertParam.Field(Content.Field.TYPE, typeList));
  215. // 3.执行操作
  216. milvusService.insert(Content.COLLECTION_NAME_NLP, fields);
  217. } catch (ApiException | NoApiKeyException e) {
  218. System.out.println(e.getMessage());
  219. }
  220. }
  221. @Test
  222. void ChatBasedContentTest() throws NoApiKeyException, InputRequiredException, InterruptedException {
  223. // 0.加载向量集合
  224. milvusService.loadCollection(Content.COLLECTION_NAME_NLP);
  225. try {
  226. String question = "查询订单";
  227. List<String> sentenceList = new ArrayList<>();
  228. sentenceList.add(question);
  229. // 1.获得向量
  230. TextEmbeddingParam param = TextEmbeddingParam
  231. .builder()
  232. .model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
  233. .texts(sentenceList).build();
  234. TextEmbedding textEmbedding = new TextEmbedding();
  235. TextEmbeddingResult result = textEmbedding.call(param);
  236. List<Double> vector = result.getOutput().getEmbeddings().get(0).getEmbedding();
  237. List<Float> floatVector = vector.stream()
  238. .map(Double::floatValue)
  239. .collect(Collectors.toList());
  240. List<List<Float>> vectors = Collections.singletonList(floatVector);
  241. SearchERPParamVo searchParamVo = SearchERPParamVo.builder()
  242. .collectionName(Content.COLLECTION_NAME_NLP)
  243. .queryVectors(vectors)
  244. .topK(3)
  245. .build();
  246. // 2.在向量数据库中进行搜索内容知识
  247. StringBuffer buffer = new StringBuffer();
  248. List<List<SearchERPResultVo>> lists = milvusService.searchERPTopKSimilarity(searchParamVo);
  249. lists.forEach(searchResultVos -> {
  250. searchResultVos.forEach(searchResultVo -> {
  251. buffer.append("问题: " + searchResultVo.getContent());
  252. buffer.append("答案: " + searchResultVo.getContentAnswer());
  253. });
  254. });
  255. // 3.进行对话
  256. String prompt = "请你充分理解下面的内容,然后回答问题, 要求仅返回答案[]中内容:";
  257. String content = buffer.toString();
  258. String resultQwen = streamCallWithCallback(prompt + content + question);
  259. // System.out.println(resultQwen);
  260. } catch (ApiException | NoApiKeyException e) {
  261. System.out.println(e.getMessage());
  262. }
  263. }
  264. public static String streamCallWithCallback(String content)
  265. throws NoApiKeyException, ApiException, InputRequiredException,InterruptedException {
  266. Constants.apiKey="sk-2106098eed1f43c9bde754f3e87038a2";
  267. Generation gen = new Generation();
  268. Message userMsg = Message
  269. .builder()
  270. .role(Role.USER.getValue())
  271. .content(content)
  272. .build();
  273. QwenParam param = QwenParam
  274. .builder()
  275. .model(Generation.Models.QWEN_PLUS)
  276. .resultFormat(QwenParam.ResultFormat.MESSAGE)
  277. .messages(Arrays.asList(userMsg))
  278. .topP(0.8)
  279. .incrementalOutput(true) // get streaming output incrementally
  280. .build();
  281. Semaphore semaphore = new Semaphore(0);
  282. StringBuilder fullContent = new StringBuilder();
  283. gen.streamCall(param, new ResultCallback<GenerationResult>() {
  284. @Override
  285. public void onEvent(GenerationResult message) {
  286. fullContent.append(message.getOutput().getChoices().get(0).getMessage().getContent());
  287. System.out.println(message);
  288. }
  289. @Override
  290. public void onError(Exception err){
  291. System.out.println(String.format("Exception: %s", err.getMessage()));
  292. semaphore.release();
  293. }
  294. @Override
  295. public void onComplete(){
  296. System.out.println("Completed");
  297. semaphore.release();
  298. }
  299. });
  300. semaphore.acquire();
  301. System.out.println("Full content: \n" + fullContent.toString());
  302. return fullContent.toString();
  303. }
  304. @Test
  305. void loadData() throws IOException {
  306. // Read the dataset file
  307. String content = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
  308. // Load dataset
  309. JSONObject dataset = JSON.parseObject(content);
  310. List<JSONObject> rows = getRows(dataset.getJSONArray("rows"), 2);
  311. System.out.println(rows);
  312. }
  313. public String readFileToString(String filePath) throws IOException {
  314. return new String(Files.readAllBytes(Paths.get(filePath)), StandardCharsets.UTF_8);
  315. }
  316. public static List<JSONObject> getRows(JSONArray dataset, int counts) {
  317. List<JSONObject> rows = new ArrayList<>();
  318. for (int i = 0; i < counts; i++) {
  319. JSONObject row = dataset.getJSONObject(i);
  320. List<Float> vectors = row.getJSONArray("title_vector").toJavaList(Float.class);
  321. Long reading_time = row.getLong("reading_time");
  322. Long claps = row.getLong("claps");
  323. Long responses = row.getLong("responses");
  324. row.put("title_vector", vectors);
  325. row.put("reading_time", reading_time);
  326. row.put("claps", claps);
  327. row.put("responses", responses);
  328. row.remove("id");
  329. rows.add(row);
  330. }
  331. return rows;
  332. }
  333. @Test
  334. void getFileds() throws IOException {
  335. String content = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
  336. // Load dataset
  337. JSONObject dataset = JSON.parseObject(content);
  338. List<InsertParam.Field> field = getFields(dataset.getJSONArray("rows"), 1);
  339. System.out.println(field);
  340. }
  341. public static List<InsertParam.Field> getFields(JSONArray dataset, int counts) {
  342. List<InsertParam.Field> fields = new ArrayList<>();
  343. List<String> titles = new ArrayList<>();
  344. List<List<Float>> title_vectors = new ArrayList<>();
  345. List<String> links = new ArrayList<>();
  346. List<Long> reading_times = new ArrayList<>();
  347. List<String> publications = new ArrayList<>();
  348. List<Long> claps_list = new ArrayList<>();
  349. List<Long> responses_list = new ArrayList<>();
  350. for (int i = 0; i < counts; i++) {
  351. JSONObject row = dataset.getJSONObject(i);
  352. titles.add(row.getString("title"));
  353. title_vectors.add(row.getJSONArray("title_vector").toJavaList(Float.class));
  354. links.add(row.getString("link"));
  355. reading_times.add(row.getLong("reading_time"));
  356. publications.add(row.getString("publication"));
  357. claps_list.add(row.getLong("claps"));
  358. responses_list.add(row.getLong("responses"));
  359. }
  360. fields.add(new InsertParam.Field("title", titles));
  361. fields.add(new InsertParam.Field("title_vector", title_vectors));
  362. fields.add(new InsertParam.Field("link", links));
  363. fields.add(new InsertParam.Field("reading_time", reading_times));
  364. fields.add(new InsertParam.Field("publication", publications));
  365. fields.add(new InsertParam.Field("claps", claps_list));
  366. fields.add(new InsertParam.Field("responses", responses_list));
  367. return fields;
  368. }
  369. @Test
  370. void searchTopKSimilarity() throws IOException {
  371. // Search data
  372. String content = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
  373. // Load dataset
  374. JSONObject dataset = JSON.parseObject(content);
  375. List<JSONObject> rows = getRows(dataset.getJSONArray("rows"), 10);
  376. // You should include the following in the main function
  377. List<List<Float>> queryVectors = new ArrayList<>();
  378. List<Float> queryVector = rows.get(0).getJSONArray("title_vector").toJavaList(Float.class);
  379. queryVectors.add(queryVector);
  380. // Prepare the outputFields
  381. List<String> outputFields = new ArrayList<>();
  382. outputFields.add("title");
  383. outputFields.add("link");
  384. // Search vectors in a collection
  385. SearchParam searchParam = SearchParam.newBuilder()
  386. .withCollectionName("medium_articles")
  387. .withVectorFieldName("title_vector")
  388. .withVectors(queryVectors)
  389. .withExpr("claps > 30 and reading_time < 10")
  390. .withTopK(3)
  391. .withMetricType(MetricType.L2)
  392. .withParams("{\"nprobe\":10,\"offset\":2, \"limit\":3}")
  393. .withConsistencyLevel(ConsistencyLevelEnum.BOUNDED)
  394. .withOutFields(outputFields)
  395. .build();
  396. R<SearchResults> response = milvusClient.search(searchParam);
  397. SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults());
  398. System.out.println("Search results");
  399. for (int i = 0; i < queryVectors.size(); ++i) {
  400. List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
  401. List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
  402. for (int j = 0; j < scores.size(); ++j) {
  403. SearchResultsWrapper.IDScore score = scores.get(j);
  404. QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(j);
  405. System.out.println("Top " + j + " ID:" + score.getLongID() + " Distance:" + score.getScore());
  406. System.out.println("Title: " + rowRecord.get("title"));
  407. System.out.println("Link: " + rowRecord.get("link"));
  408. }
  409. }
  410. }
  411. }

4、查询

// First, run a vector search for corpus entries semantically close to the user's message.
List<Question> questionList = mivusService.searchNewPaddleQuestion(req.getMessage(), "1", appType);

  1. **
  2. * 根据问题进行向量查询,采用Paddle服务 采用新的文本分类方法
  3. * @param question 用户的问题文本
  4. * @return 相关的初始问题知识列表
  5. */
  6. public List<Question> searchNewPaddleQuestion(String question, String type, String appType) {
  7. // 0.加载向量集合
  8. String collection = Content.COLLECTION_NAME_NLP;
  9. if (appType.equals("1")) {
  10. collection = Content.COLLECTION_NAME_NLP_APP;
  11. }
  12. loadCollection(collection);
  13. List<Question> resultList = new LinkedList<>();
  14. PaddleNewTextVo paddleNewTextVo = null;
  15. try {
  16. List<String> sentenceList = new ArrayList<>();
  17. sentenceList.add(question);
  18. // 1.获得向量
  19. paddleNewTextVo = getNewNLPVectorsLists(sentenceList);
  20. log.info("实时向量值 : {}", paddleNewTextVo.getPredictedList());
  21. List<List<Double>> vectors = paddleNewTextVo.getVector();
  22. List<List<Float>> floatVectors = new ArrayList<>();
  23. for (List<Double> innerList : vectors) {
  24. List<Float> floatInnerList = new ArrayList<>();
  25. for (Double value : innerList) {
  26. floatInnerList.add(value.floatValue());
  27. }
  28. floatVectors.add(floatInnerList);
  29. }
  30. List<Integer> predictedList = paddleNewTextVo.getPredictedList();
  31. List<String> labelStrings = new ArrayList<>();
  32. HashSet<Integer> setType = new HashSet();
  33. int topK = 3;
  34. if(!predictedList.isEmpty()) {
  35. // 去重
  36. for (Integer number : predictedList) {
  37. setType.add(number);
  38. if (number == 2) {
  39. // 如何是 2
  40. topK = 1;
  41. }
  42. }
  43. for (Integer label : setType) {
  44. labelStrings.add("'" + label + "'");
  45. }
  46. }
  47. String typeResult = "[" + String.join(", ", labelStrings) + "]";
  48. SearchNLPParamVo searchParamVo = SearchNLPParamVo.builder()
  49. .collectionName(collection)
  50. //.expr("type == '" + type + "'")
  51. .expr("type in ['0','1','2']")
  52. //.expr("type in " + typeResult + " ")
  53. .queryVectors(floatVectors)
  54. .topK(topK)
  55. .build();
  56. // 2.在向量数据库中进行搜索内容知识
  57. List<List<SearchNLPResultVo>> lists = searchNLPERPTopKSimilarity(searchParamVo);
  58. lists.forEach(searchResultVos -> {
  59. searchResultVos.forEach(searchResultVo -> {
  60. log.info(searchResultVo.getContent());
  61. log.info(searchResultVo.getContentAnswer());
  62. Question question1 = new Question();
  63. question1.setQuestionId(Long.valueOf(searchResultVo.getId()));
  64. question1.setQuestion(searchResultVo.getContent());
  65. question1.setAnswer(searchResultVo.getContentAnswer());
  66. question1.setTitle(searchResultVo.getTitle());
  67. question1.setParam(searchResultVo.getParam());
  68. question1.setType(searchResultVo.getType());
  69. question1.setLabel(searchResultVo.getLabel());
  70. resultList.add(question1);
  71. });
  72. });
  73. } catch (ApiException | IOException e) {
  74. log.error(e.getMessage());
  75. }
  76. // 将查询到的结果转换为之前构造的 Question 的格式返回给前端
  77. return resultList;
  78. }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/823402
推荐阅读
相关标签
  

闽ICP备14008679号