赞
踩
public class PageQuery implements Serializable { private static final long serialVersionUID = 7172912761241281958L; /** * 当前页 */ private Integer page = 0; /** * 条目数 */ private Integer size = 20; /** * 关键字 */ @ApiModelProperty(value = "搜索关键字") private String keyword; /** 排序字段 */ @ApiModelProperty(value = "排序字段") private String sortField; /** 排序方法 */ @ApiModelProperty(value = "排序方式 asc,desc") private String sortWay; public Integer getPage() { return page; } public void setPage(Integer page) { this.page = page; } public Integer getSize() { return size; } public void setSize(Integer size) { this.size = size; } public String getKeyword() { return keyword; } public void setKeyword(String keyword) { this.keyword = keyword; } public String getSortField() { return sortField; } public void setSortField(String sortField) { this.sortField = sortField; } public String getSortWay() { return sortWay; } public void setSortWay(String sortWay) { this.sortWay = sortWay; }
import java.io.Serializable;

import org.elasticsearch.search.sort.SortOrder;

/**
 * A single sort instruction: which field to sort on and in which direction.
 */
public class SortParam implements Serializable {

    private static final long serialVersionUID = -379151600753725891L;

    /** Name of the field to sort by. */
    private String fieldName;

    /** Sort direction. */
    private SortOrder order;

    public String getFieldName() {
        return this.fieldName;
    }

    public void setFieldName(final String fieldName) {
        this.fieldName = fieldName;
    }

    public SortOrder getOrder() {
        return this.order;
    }

    public void setOrder(final SortOrder order) {
        this.order = order;
    }
}
Repository 层无需编写任何方法,直接继承 `org.springframework.data.elasticsearch.repository.ElasticsearchRepository<T, ID>` 即可,例如:
/**
 * Spring Data Elasticsearch repository for {@code Xxx} documents with {@code String} ids.
 * All operations are inherited from ElasticsearchRepository; no custom methods are declared.
 */
public interface XxxRepository
        extends org.springframework.data.elasticsearch.repository.ElasticsearchRepository<Xxx, String> {
    // intentionally empty
}
import cn.venny.base.beans.PageQuery; import cn.venny.base.beans.SortParam; import cn.venny.base.utils.CollectionUtils; import cn.venny.base.utils.StringUtils; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.search.sort.SortOrder; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.lang.Nullable; import java.lang.reflect.Field; import java.lang.reflect.ParameterizedType; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; public interface IEsBaseService<T, ID> { /** 属性缓存 */ Map<Class<?>, List<Field>> FIELD_CACHE = new ConcurrentHashMap<>(); /** * 获取实体class类型 * @return 实体clazz类型 */ default Class<T> getEntityClass() { return (Class<T>) (((ParameterizedType) this.getClass().getGenericSuperclass()).getActualTypeArguments()[0]); } /** * 获取实体所有属性(排除序列号属性) * @return 属性集合 */ default List<Field> getEntityAllField() { Class<T> entityClass = getEntityClass(); if (FIELD_CACHE.get(entityClass) != null) { return FIELD_CACHE.get(entityClass); } Field[] currentFields = entityClass.getDeclaredFields(); Class<? 
super T> superclass = entityClass.getSuperclass(); List<Field> supperFields = new ArrayList<>(); // 可能有多层继承 while (!superclass.equals(Object.class)) { Field[] declaredFields = superclass.getDeclaredFields(); Collections.addAll(supperFields, declaredFields); superclass = superclass.getSuperclass(); } // 排除序列化字段 List<Field> fieldList = Arrays.stream(currentFields).filter(f -> !"serialVersionUID".equalsIgnoreCase(f.getName())).distinct().collect(Collectors.toList()); // 父类字段 List<Field> superFieldList = supperFields.stream().filter(f -> !"serialVersionUID".equalsIgnoreCase(f.getName())).distinct().collect(Collectors.toList()); if (CollectionUtils.notEmpty(superFieldList)) { fieldList.addAll(superFieldList); } FIELD_CACHE.put(entityClass, fieldList); return fieldList; } /** * 页面返回字段 * @return 实体所有字段名称数组 */ default String[] returnFields() { List<Field> fieldList = getEntityAllField(); String[] fields = new String[fieldList.size()]; for (int i = 0; i < fieldList.size(); i++) { fields[i] = fieldList.get(i).getName(); } return fields; } /** * 排序字段、排序方式设置 * 默认按照创建时间倒叙排列 * @param query 查询参数 * @param <Q> 查询参数实体 */ default <Q extends PageQuery> List<SortParam> sortFields(Q query) { if (StringUtils.isEmpty(query.getSortField()) || StringUtils.isEmpty(query.getSortWay())) { return null; } SortParam sp = new SortParam(); sp.setFieldName(query.getSortField()); sp.setOrder(SortOrder.valueOf(query.getSortWay().toUpperCase())); return CollectionUtils.singleList(sp); } /** * 构建过滤条件 */ default <Q extends PageQuery> void buildFilterCondition(BoolQueryBuilder filter, Q queryParam) { // eg: // 带分词匹配 // filter.must(QueryBuilders.matchQuery("xxx", query.getXxxx())); // 不分词匹配 // filter.must(QueryBuilders.termQuery("xxx", query.getXxx())); // 范围匹配 // filter.must(QueryBuilders.rangeQuery("createTime").gte(query.getCreateTime() + " 00:00:00")); } <S extends T> S save(S entity); <S extends T> Iterable<S> saveAll(Iterable<S> entities); Optional<T> findById(ID id); boolean existsById(ID id); 
Collection<T> findAll(); Collection<T> findAllById(Collection<ID> ids); long count(); void deleteById(ID id); void delete(T entity); void deleteAllById(Iterable<? extends ID> ids); void deleteAll(Collection<? extends T> entities); void deleteAll(); Iterable<T> findAll(Sort sort); Page<T> findAll(Pageable pageable); /** * 模糊搜索 * @param entity 请求实体 * @param fields 查询字段名称 * @param pageable 分页对象 * @return 分页参数 */ Page<T> searchSimilar(T entity, @Nullable String[] fields, Pageable pageable); /** * 分页查询(自定义) * @param query 请求参数 * @param <Q> 请求参数类型 * @return 分页数据 */ <Q extends PageQuery> Page<T> search(Q query); /** * 分页查询(自定义) * @param query 请求参数 * @param <Q> 请求参数类型 * @return 数据总数 */ <Q extends PageQuery> Long count(Q query); /** * 分页查询(自定义) * @param query 请求参数 * @param <Q> 请求参数类型 * @return 数据总数 */ <Q extends PageQuery> List<T> list(Q query); /** * 分页查询(自定义) * @param query 请求参数 * @param <Q> 请求参数类型 * @param columnName 返回列名 * @return 数据总数 */ <Q extends PageQuery> List<T> list(Q query, String... columnName); /** * 根据ID保存或者更新 * @param entity 请求实体 * @param <S> 实体类型 */ <S extends T> void update(S entity); /** * 根据ID保存或者更新 * @param entity 请求实体 * @param <S> 实体类型 */ <S extends T> void updateAndFlush(S entity); /** * 批量保存或者更新 * @param entities 请求实体集合 * @param <S> 实体类型 */ <S extends T> void update(Collection<S> entities); /** * 批量保存或者更新 * @param entities 请求实体集合 * @param <S> 实体类型 */ <S extends T> void updateAndFlush(Collection<S> entities); /** * 批量保存或者更新 * @param entities 请求实体集合 * @param <S> 实体类型 */ <S extends T> void saveOrUpdate(Collection<S> entities);
import xx.xx.xx.PageQuery; import xx.xx.xx.SortParam; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.sort.SortBuilders; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.annotation.Id; import org.springframework.data.domain.*; import org.springframework.data.elasticsearch.core.ElasticsearchRestTemplate; import org.springframework.data.elasticsearch.core.SearchHit; import org.springframework.data.elasticsearch.core.SearchHits; import org.springframework.data.elasticsearch.core.document.Document; import org.springframework.data.elasticsearch.core.query.BulkOptions; import org.springframework.data.elasticsearch.core.query.NativeSearchQuery; import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder; import org.springframework.data.elasticsearch.core.query.UpdateQuery; import java.lang.reflect.Field; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; public abstract class EsBaseServiceImpl<T, ID, M extends org.springframework.data.elasticsearch.repository.ElasticsearchRepository<T, ID>> implements IEsBaseService<T, ID> { @Autowired(required = false) public M repository; @Autowired public ElasticsearchRestTemplate elasticsearchRestTemplate; @Override public <S extends T> S save(S entity) { return repository.save(entity); } @Override public <S extends T> Iterable<S> saveAll(Iterable<S> entities) { return repository.saveAll(entities); } @Override public Optional<T> findById(ID id) { return repository.findById(id); } @Override public boolean existsById(ID id) { return repository.existsById(id); } @Override public Collection<T> findAll() { return list(new PageQuery()); } @Override public Collection<T> 
findAllById(Collection<ID> ids) { return (Collection<T>) repository.findAllById(ids); } @Override public long count() { return repository.count(); } @Override public void deleteById(ID id) { repository.deleteById(id); } @Override public void delete(T entity) { final List<Field> fields = getEntityAllField(); AtomicInteger num = new AtomicInteger(); // 构建过滤条件 BoolQueryBuilder filter = buildFilterBoolQueryBuilder(fields, entity, num); if (num.intValue() < 1) { return; } // 构建查询条件 NativeSearchQueryBuilder queryBuilder = new NativeSearchQueryBuilder(); queryBuilder.withFilter(filter); // 执行删除 elasticsearchRestTemplate.delete(queryBuilder.build(), getEntityClass()); } @Override public void deleteAllById(Iterable<? extends ID> ids) { repository.deleteAllById(ids); } @Override public void deleteAll(Collection<? extends T> entities) { if (CollectionUtils.isEmpty(entities)) { return; } entities.forEach(this::delete); } @Override public void deleteAll() { repository.deleteAll(); } @Override public Iterable<T> findAll(Sort sort) { return repository.findAll(sort); } @Override public Page<T> findAll(Pageable pageable) { return repository.findAll(pageable); } @Override public Page<T> searchSimilar(T entity, String[] fields, Pageable pageable) { return repository.searchSimilar(entity, fields, pageable); } @Override public <Q extends PageQuery> Page<T> search(Q query) { Long total = count(query); SearchHits<T> searchHits = commonSearch(query, true); if (searchHits.getTotalHits() > 0) { List<T> searchProductList = searchHits.stream().map(SearchHit::getContent).collect(Collectors.toList()); return new PageImpl<>(searchProductList, PageRequest.of(query.getPage(), query.getSize()), total); } return new PageImpl<T>(new ArrayList<>(), PageRequest.of(query.getPage(), query.getSize()), total); } @Override public <Q extends PageQuery> Long count(Q query) { return commonSearch(query, false).getTotalHits(); } @Override public <Q extends PageQuery> List<T> list(Q query) { SearchHits<T> 
searchHits = commonSearch(query, false); if (searchHits.getTotalHits() > 0) { return searchHits.stream().map(SearchHit::getContent).collect(Collectors.toList()); } return null; } @Override public <Q extends PageQuery> List<T> list(Q query, String... columnName) { SearchHits<T> searchHits = commonSearch(query, false, columnName); if (searchHits.getTotalHits() > 0) { return searchHits.stream().map(SearchHit::getContent).collect(Collectors.toList()); } return null; } @Override public <S extends T> void update(S entity) { commonUpdate(CollectionUtils.singleList(entity), false); } @Override public <S extends T> void updateAndFlush(S entity) { commonUpdate(CollectionUtils.singleList(entity), true); } @Override public <S extends T> void update(Collection<S> entities) { commonUpdate(entities, false); } @Override public <S extends T> void updateAndFlush(Collection<S> entities) { commonUpdate(entities, true); } @Override public <S extends T> void saveOrUpdate(Collection<S> entities) { Map<ID, Map<String, Object>> tempMap = idTempMap(entities); if (tempMap == null) { return; } List<ID> ids = new ArrayList<>(tempMap.keySet()); Collection<T> records = findAllById(ids); if (CollectionUtils.isEmpty(records)) { saveAll(entities); return; } List<T> save = new CopyOnWriteArrayList<>(); List<T> update = new CopyOnWriteArrayList<>(); records.forEach(entity -> { Field[] declaredFields = entity.getClass().getDeclaredFields(); for (Field field : declaredFields) { if (!field.isAnnotationPresent(Id.class)) { continue; } ID id = (ID) doGetFieldValue(field, entity); Map<String, Object> map = tempMap.get(id); if (map == null) { // save save.add(entity); } else { // update update.add(entity); } } }); if (CollectionUtils.notEmpty(save)) { saveAll(save); } if (CollectionUtils.notEmpty(update)) { commonUpdate(update, true); } } /** * 通用更新 * @param entities 请求实体 * @param flush 是否立即刷新 * @param <S> 请求实体类型 */ private <S extends T> void commonUpdate(Collection<S> entities, Boolean flush) { Map<ID, 
Map<String, Object>> tempMap = idTempMap(entities); if (tempMap == null) { return; } List<UpdateQuery> queries = new CopyOnWriteArrayList<>(); tempMap.forEach((id, params) -> { UpdateQuery build = UpdateQuery.builder(String.valueOf(id)) .withDocument(Document.from(params)) .build(); queries.add(build); }); if (flush) { // 立刻刷新,损害性能 elasticsearchRestTemplate.bulkUpdate( queries, BulkOptions.builder().withRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).build(), elasticsearchRestTemplate.getIndexCoordinatesFor(getEntityClass())); } else { // 不执行立刻刷新,损害性能 elasticsearchRestTemplate.bulkUpdate(queries, getEntityClass()); } } private <S extends T> Map<ID, Map<String, Object>> idTempMap(Collection<S> entities) { if (CollectionUtils.isEmpty(entities)) { return null; } final List<Field> fields = getEntityAllField(); Map<ID, Map<String, Object>> tempMap = new ConcurrentHashMap<>(); entities.forEach(entity -> buildIdMapParams(fields, entity, tempMap)); if (CollectionUtils.isEmpty(tempMap)) { return null; } return tempMap; } private <S extends T> void buildIdMapParams(List<Field> fields, S entity, Map<ID, Map<String, Object>> tempMap) { // 用来存放参数 Map<String, Object> params = new LinkedHashMap<>(); for (Field field : fields) { Object o = doGetFieldValue(field, entity); if (o == null) { continue; } params.put(field.getName(), o); if (field.isAnnotationPresent(Id.class)) { // 主键ID tempMap.put((ID) o, params); } } } private BoolQueryBuilder buildFilterBoolQueryBuilder(List<Field> fields, T entity, AtomicInteger num) { // 查询构建器 BoolQueryBuilder filter = QueryBuilders.boolQuery(); for (Field field : fields) { Object obj = doGetFieldValue(field, entity); if (obj == null) { continue; } // 计数器统计数量+1 num.incrementAndGet(); filter.must(QueryBuilders.termQuery(field.getName(), obj)); } return filter; } /** * 获取属性值 * @param field field对象 * @param entity 实体类 * @return 属性值 */ private Object doGetFieldValue(Field field, T entity) { field.setAccessible(true); // 一般属性 Object o = null; try { 
o = field.get(entity); } catch (IllegalAccessException e) { log.error("获取属性异常", e); } return o; } /** * 通用查询 * @param query 查询条件 * @param <Q> 查询条件类型 * @param page 是否需要分页 * @return es响应对象 */ private <Q extends PageQuery> SearchHits<T> commonSearch(Q query, boolean page, String... columnName) { // 构建查询条件 NativeSearchQueryBuilder queryBuilder = new NativeSearchQueryBuilder(); // 查询构建器 BoolQueryBuilder builder = QueryBuilders.boolQuery(); // 构建过滤条件 buildFilterCondition(builder, query); queryBuilder.withQuery(builder); List<SortParam> sorts = sortFields(query); if (CollectionUtils.notEmpty(sorts)) { for (SortParam sort : sorts) { queryBuilder.withSort(SortBuilders.fieldSort(sort.getFieldName()).order(sort.getOrder())); } } // 分页条件 if (page) { queryBuilder.withPageable(PageRequest.of(query.getPage(), query.getSize())); } NativeSearchQuery nativeSearchQuery = queryBuilder.build(); // 页面返回字段设置 if (columnName != null && columnName.length > 0) { nativeSearchQuery.addFields(columnName); } else { nativeSearchQuery.addFields(returnFields()); } // 使用ElasticsearchRestTemplate进行复杂查询 return elasticsearchRestTemplate.search(nativeSearchQuery, this.getEntityClass()); } }
以上代码用到的 StringUtils、CollectionUtils 是自定义工具类,实现很简单:分别继承 Spring 对应的工具类,再补充一些常用方法,例如 notEmpty() 就是调用 Spring 的 isEmpty() 后取反。具体业务的 Service 继承 EsBaseServiceImpl 后,通过重写 buildFilterCondition 方法来构造查询条件,示例如下:
/**
 * Builds the filter conditions for this service's searches (example override).
 * Casts the generic request to the concrete request type and adds bool-query clauses from it.
 *
 * @param filter     bool query that receives the clauses
 * @param queryParam request parameters; expected to be an {@code XxxQuery}
 * @param <Q>        request parameter type
 */
@Override
public <Q extends PageQuery> void buildFilterCondition(BoolQueryBuilder filter, Q queryParam) {
    // Cast to the actual request type; any other PageQuery subtype fails fast with ClassCastException
    XxxQuery query = (XxxQuery) queryParam;
    // Build conditions from the actual parameters, e.g.:
    // analyzed (full-text) match
    // filter.must(QueryBuilders.matchQuery("xxx", query.getXxx()));
    // exact (non-analyzed) match
    // filter.must(QueryBuilders.termQuery("xxx", query.getXxx()));
    // range match
    // filter.must(QueryBuilders.rangeQuery("createTime").gte(query.getCreateTime() + " 00:00:00"));
}
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。