赞
踩
MyBatis 是现在比较流行的 ORM 框架,得益于它简单易用,虽然增加了开发者的一些操作,但是带来了设计和使用上的灵活,得到广泛的使用。之前的一篇文章MyBatis 初始化之XML解析详解中我们已经知道了 MyBatis 的加载 XML 配置文件和加载 mappers 配置文件的流程,最终都是封装到了 Configuration 对象中。而MyBatis 插件功能也是 MyBatis 的一块重要功能。我们知道其可以在 DAO 层进行拦截,如实现分页、SQL语句执行的性能监控、公共字段统一赋值等功能。但对其内部实现机制,涉及的软件设计模式,编程思想往往没有深入的理解。本文对 MyBatis 加载 XML 配置文件中 plugins 插件流程以及对 MyBatis 插件实现原理进行深入分析,并实现自己的简单分页插件。
MyBatis 插件主要用到的类再 org.apache.ibatis.plugin 包下,如下图所示:
从命名可以看到,叫 MyBatis 拦截器可能更合适些,实际上它就是一个拦截器,使用 JDK 动态代理方式,实现在方法级别上进行拦截。支持拦截的方法有以下几种:
插件配置信息也是配置在 XML 配置文件中,所以我们直接从 XMLConfigBuilder 的 parseConfiguration 方法开始,可以看到 pluginElement(root.evalNode("plugins")) 用于解析 XML配置中的 plugins 节点标签。而 pluginElement 方法实现的主要功能就是遍历 plugins 节点中的 plugin 子节点,获取配置的 interceptor 属性为具体的插件实现,并将其添加到 Configuration 中的 interceptorChain 中。interceptorChain 是拦截器链,将拦截器存在 interceptors (是一个 Interceptor 数组)中。
private void parseConfiguration(XNode root) { try { ... // 解析节点 pluginElement(root.evalNode("plugins")); ... // 解析节点 mapperElement(root.evalNode("mappers")); } catch (Exception e) { throw new BuilderException("Error parsing SQL Mapper Configuration. Cause: " + e, e); }}private void pluginElement(XNode parent) throws Exception { if (parent != null) { for (XNode child : parent.getChildren()) { String interceptor = child.getStringAttribute("interceptor"); Properties properties = child.getChildrenAsProperties(); Interceptor interceptorInstance = (Interceptor) resolveClass(interceptor).getDeclaredConstructor().newInstance(); interceptorInstance.setProperties(properties); configuration.addInterceptor(interceptorInstance); } }}
至此,插件加载流程结束,相对比较简单,具体流程图如下图所示:
我们说 MyBatis 插件的实现机制主要是基于 JDK 动态代理实现的。那就有必要来了解下这些代理对象是如何生成的。MyBatis 插件机制是拦截执行器 Executor、参数处理器 ParameterHandler、结果集处理器 ResultSetHandler、SQL 语法构建器 StatementHandler 的,那我们就先看看,这些对象的创建。这些对象的创建在 Configuration 类中,源码如下所示:
public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) { ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql); parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler); return parameterHandler;}public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler, ResultHandler resultHandler, BoundSql boundSql) { ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds); resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler); return resultSetHandler;}public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql); statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler); return statementHandler;}public Executor newExecutor(Transaction transaction) { return newExecutor(transaction, defaultExecutorType);}public Executor newExecutor(Transaction transaction, ExecutorType executorType) { executorType = executorType == null ? defaultExecutorType : executorType; executorType = executorType == null ? ExecutorType.SIMPLE : executorType; Executor executor; if (ExecutorType.BATCH == executorType) { executor = new BatchExecutor(this, transaction); } else if (ExecutorType.REUSE == executorType) { executor = new ReuseExecutor(this, transaction); } else { executor = new SimpleExecutor(this, transaction); } if (cacheEnabled) { executor = new CachingExecutor(executor); } executor = (Executor) interceptorChain.pluginAll(executor); return executor;}
观察源码,发现这些可拦截的类对应的对象生成都是通过 InterceptorChain 的 pluginAll 方法来创建的,进一步观察 pluginAll 方法,如下:
public class InterceptorChain { private final List interceptors = new ArrayList<>(); public Object pluginAll(Object target) { for (Interceptor interceptor : interceptors) { target = interceptor.plugin(target); } return target; } public void addInterceptor(Interceptor interceptor) { interceptors.add(interceptor); } public List getInterceptors() { return Collections.unmodifiableList(interceptors); }}
我们之前知道了配置的拦截器都是放进了 interceptorChain 的 interceptors 中,而 pluginAll 方法中遍历所有拦截器,并调用拦截器的 plugin 方法生成代理对象,注意生成代理对象重新赋值给 target,这里需要注意的是如果有多个拦截器的话,生成的代理对象会被另一个代理对象代理,从而形成一个代理链,执行的时候,依次执行所有拦截器的拦截逻辑代码。如果 target 对象不是某个拦截器关注的,我们可以在自己实现的拦截器中的 plugin 方法进行判断,如果需要则使用Plugin#wrap 方法创建代理对象。代码如下所示:
public class Plugin implements InvocationHandler { private final Object target; private final Interceptor interceptor; private final Map, Set> signatureMap; private Plugin(Object target, Interceptor interceptor, Map, Set> signatureMap) { this.target = target; this.interceptor = interceptor; this.signatureMap = signatureMap; } public static Object wrap(Object target, Interceptor interceptor) { Map, Set> signatureMap = getSignatureMap(interceptor); Class> type = target.getClass(); Class>[] interfaces = getAllInterfaces(type, signatureMap); if (interfaces.length > 0) { return Proxy.newProxyInstance( type.getClassLoader(), interfaces, new Plugin(target, interceptor, signatureMap)); } return target; } ... }
可以看到 Plugin 类实现了 InvocationHandler,使用 JDK 动态代理, 而Plugin#wrap 方法通过获取 signatureMap 来选择需要拦截的方法。
Plugin 类实现了 InvocationHandler 接口,真正去执行 Executor、ParameterHandler、ResultSetHandler 和 StatementHandler 类中的方法的对象是代理对象,所以在执行方法时,首先调用的是 Plugin 类的 invoke 方法,如下:
@Overridepublic Object invoke(Object proxy, Method method, Object[] args) throws Throwable { try { Set methods = signatureMap.get(method.getDeclaringClass()); if (methods != null && methods.contains(method)) { return interceptor.intercept(new Invocation(target, method, args)); } return method.invoke(target, args); } catch (Exception e) { throw ExceptionUtil.unwrapThrowable(e); }}
首先从 signatureMap 从获取需要拦截的方法集合,这是在我们自己实现的拦截器中通过 Signature 注解声明的。判断当前方法需不需要执行拦截逻辑,需要的话,执行拦截逻辑方法(即 Interceptor 接口的 intercept 方法实现),不需要则直接执行原方法。这里用到了 Invocation 类,它封装了代理对象,拦截方法和参数,后面在实现自己的拦截器的时候需要使用。
public class Invocation { private final Object target; private final Method method; private final Object[] args; public Invocation(Object target, Method method, Object[] args) { this.target = target; this.method = method; this.args = args; } public Object getTarget() { return target; } public Method getMethod() { return method; } public Object[] getArgs() { return args; } public Object proceed() throws InvocationTargetException, IllegalAccessException { return method.invoke(target, args); }}
我们通过实现一个简单分页插件,来熟悉 MyBatis 插件的编写规则,配置如下所示:
代码实现主要需要实现 Interceptor 接口,其中主要有三个方法:
@Intercepts({ @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class}),})public class PageInterceptor implements Interceptor { /** * 默认页码 */ private Integer defaultPageIndex; /** * 默认每页数据条数 */ private Integer defaultPageSize; @Override public Object intercept(Invocation invocation) throws Throwable { StatementHandler statementHandler = getUnProxyObject(invocation); MetaObject metaObject = SystemMetaObject.forObject(statementHandler); String sql = getSql(metaObject); if (!checkSelect(sql)) { // 不是select语句,进入责任链下一层 return invocation.proceed(); } BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql"); Object parameterObject = boundSql.getParameterObject(); Page page = getPage(parameterObject); if (page == null) { // 没有传入page对象,不执行分页处理,进入责任链下一层 return invocation.proceed(); } // 设置分页默认值 if (page.getPageNum() == null) { page.setPageNum(this.defaultPageIndex); } if (page.getPageSize() == null) { page.setPageSize(this.defaultPageSize); } // 设置分页总数,数据总数 setTotalToPage(page, invocation, metaObject, boundSql); // 校验分页参数 checkPage(page); return changeSql(invocation, metaObject, boundSql, page); } @Override public Object plugin(Object target) { // 生成代理对象 return Plugin.wrap(target, this); } @Override public void setProperties(Properties properties) { // 初始化配置的默认页码,无配置则默认1 this.defaultPageIndex = Integer.parseInt(properties.getProperty("defaultPageIndex", "1")); // 初始化配置的默认数据条数,无配置则默认20 this.defaultPageSize = Integer.parseInt(properties.getProperty("defaultPageSize", "20")); } /** * 从代理对象中分离出真实对象 * * @param invocation * @return */ private StatementHandler getUnProxyObject(Invocation invocation) { // 取出被拦截的对象 StatementHandler statementHandler = (StatementHandler) invocation.getTarget(); MetaObject metaStmtHandler = SystemMetaObject.forObject(statementHandler); Object object = null; // 分离代理对象 while (metaStmtHandler.hasGetter("h")) { object = metaStmtHandler.getValue("h"); metaStmtHandler = SystemMetaObject.forObject(object); } return object == null ? statementHandler : (StatementHandler) object; } /** * 判断是否是select语句 * * @param sql * @return */ private boolean checkSelect(String sql) { // 去除sql的前后空格,并将sql转换成小写 sql = sql.trim().toLowerCase(); return sql.indexOf("select") == 0; } /** * 获取分页参数 * * @param parameterObject * @return */ private Page getPage(Object parameterObject) { if (parameterObject == null) { return null; } if (parameterObject instanceof Map) { // 如果传入的参数是map类型的,则遍历map取出Page对象 Map parameMap = (Map) parameterObject; Set keySet = parameMap.keySet(); for (String key : keySet) { Object value = parameMap.get(key); if (value instanceof Page) { // 返回Page对象 return (Page) value; } } } else if (parameterObject instanceof Page) { // 如果传入的是Page类型,则直接返回该对象 return (Page) parameterObject; } // 初步判断并没有传入Page类型的参数,返回null return null; } /** * 获取数据总数 * * @param invocation * @param metaObject * @param boundSql * @return */ private Long getTotal(Invocation invocation, MetaObject metaObject, BoundSql boundSql) { // 获取当前的mappedStatement对象 MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); // 获取配置对象 Configuration configuration = mappedStatement.getConfiguration(); // 获取当前需要执行的sql String sql = getSql(metaObject); // 改写sql语句,实现返回数据总数 $_paging取名是为了防止数据库表重名 String countSql = "select count(*) as total from (" + sql + ") $_paging"; // 获取拦截方法参数,拦截的是connection对象 Connection connection = (Connection) invocation.getArgs()[0]; PreparedStatement pstmt = null; Long total = 0L; try { // 预编译查询数据总数的sql语句 pstmt = connection.prepareStatement(countSql); // 构建boundSql对象 BoundSql countBoundSql = new BoundSql(configuration, countSql, boundSql.getParameterMappings(), boundSql.getParameterObject()); // 构建parameterHandler用于设置sql参数 ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), countBoundSql); // 设置sql参数 parameterHandler.setParameters(pstmt); //执行查询 ResultSet rs = pstmt.executeQuery(); while (rs.next()) { total = rs.getLong("total"); } } catch (SQLException e) { e.printStackTrace(); } finally { if (pstmt != null) { try { pstmt.close(); } catch (SQLException e) { e.printStackTrace(); } } } // 返回总数据数 return total; } /** * 设置总数据数、总页数 * * @param page * @param invocation * @param metaObject * @param boundSql */ private void setTotalToPage(Page page, Invocation invocation, MetaObject metaObject, BoundSql boundSql) { // 总数据数 long total = getTotal(invocation, metaObject, boundSql); // 计算总页数 Integer totalPage = (int) (total/page.getPageSize()); if (total % page.getPageSize() != 0) { totalPage = totalPage + 1; } page.setTotal(total); page.setPages(totalPage); } /** * 校验分页参数 * * @param page */ private void checkPage(Page page) { // 如果当前页码大于总页数,抛出异常 if (page.getPageNum() > page.getPages()) { throw new RuntimeException("当前页码[" + page.getPageNum() + "]大于总页数[" + page.getPages() + "]"); } // 如果当前页码小于总页数,抛出异常 if (page.getPageNum() < 1) { throw new RuntimeException("当前页码[" + page.getPageNum() + "]小于[1]"); } } /** * 修改当前查询的sql * * @param invocation * @param metaObject * @param boundSql * @param page * @return */ private Object changeSql(Invocation invocation, MetaObject metaObject, BoundSql boundSql, Page page) throws Exception { // 获取当前查询的sql String sql = getSql(metaObject); // 修改sql,$_paging_table_limit取名是为了防止数据库表重名 String newSql = "select * from (" + sql + ") $_paging_table_limit limit ?, ?"; // 设置当前sql为修改后的sql setSql(metaObject, newSql); // 获取PreparedStatement对象 PreparedStatement pstmt = (PreparedStatement) invocation.proceed(); // 获取sql的总参数个数 int parameCount = pstmt.getParameterMetaData().getParameterCount(); // 设置分页参数 pstmt.setInt(parameCount - 1, (page.getPageNum() - 1) * page.getPageSize()); pstmt.setInt(parameCount, page.getPageSize()); return pstmt; } /** * 获取当前查询的sql * * @param metaObject * @return */ private String getSql(MetaObject metaObject) { return (String) metaObject.getValue("delegate.boundSql.sql"); } /** * 设置当前查询的sql * * @param metaObject */ private void setSql(MetaObject metaObject, String sql) { metaObject.setValue("delegate.boundSql.sql", sql); }}
从上面的分析,MyBatis 插件就是对 ParameterHandler、ResultSetHandler、StatementHandler、Executor 这四个接口上的方法进行拦截,利用JDK动态代理机制,为这些接口的实现类创建代理对象,由代理对象的执行真正的逻辑。这里需要对 JDK 动态代理,以及上面说的四个接口需要有一定了解。总结 MyBatis 插件实现的步骤主要是:
另外,如果配置了多个拦截器的话,会出现层层代理的情况,即代理对象代理了另外一个代理对象,形成一个代理链,执行的时候,也是层层执行。不过需要注意以下两点:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。