@@ -1,11 +1,16 @@
package cn.iocoder.yudao.framework.datapermission.core.interceptor ;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper ;
import cn.hutool.core.collection.CollUtil ;
import cn.iocoder.yudao.framework.common.util.collection.SetUtils ;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRule ;
import cn.iocoder.yudao.framework.datapermission.core.rule.DataPermissionRuleFactory ;
import com.alibaba.ttl.TransmittableThreadLocal ;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils ;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils ;
import com.baomidou.mybatisplus.core.toolkit.StringPool ;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport ;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor ;
import lombok.RequiredArgsConstructor ;
import net.sf.jsqlparser.expression.* ;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression ;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression ;
@@ -24,33 +29,58 @@ import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds ;
import java.sql.Connection ;
import java.util.Collection ;
import java.util.Deque ;
import java.util.LinkedList ;
import java.util.List ;
import java.util.* ;
import java.util.concurrent.ConcurrentHashMap ;
@RequiredArgsConstructor
public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {
// private TenantLineHandler tenantLineHandler ;
private final DataPermissionRuleFactory ruleFactory ;
@Override
private final MappedStatementCache mappedStatementCache = new MappedStatementCache ( ) ;
@Override // SELECT 场景
public void beforeQuery ( Executor executor , MappedStatement ms , Object parameter , RowBounds rowBounds , ResultHandler resultHandler , BoundSql boundSql ) {
// TODO 芋艿:这个判断,后续读懂下
if ( InterceptorIgnoreHelper . willIgnoreTenantLin e( ms . getId ( ) ) ) return ;
// 获得 Mapper 对应的数据权限的规则
List < DataPermissionRule > rules = ruleFactory . getDataPermissionRul e( ms . getId ( ) ) ;
if ( mappedStatementCache . noRewritable ( ms , rules ) ) { // 如果无需重写,则跳过
return ;
}
PluginUtils . MPBoundSql mpBs = PluginUtils . mpBoundSql ( boundSql ) ;
// TODO 芋艿: null=》DataScope
mpBs . sql ( parserSingle ( mpBs . sql ( ) , null ) ) ;
try {
// 初始化上下文
ContextHolder . init ( rules ) ;
// 处理 SQL
mpBs . sql ( parserSingle ( mpBs . sql ( ) , null ) ) ;
} finally {
addMappedStatementCache ( ms ) ;
ContextHolder . clear ( ) ;
}
}
@Override
@Override // 只处理 UPDATE / DELETE 场景
public void beforePrepare ( StatementHandler sh , Connection connection , Integer transactionTimeout ) {
PluginUtils . MPStatementHandler mpSh = PluginUtils . mpStatementHandler ( sh ) ;
MappedStatement ms = mpSh . mappedStatement ( ) ;
SqlCommandType sct = ms . getSqlCommandType ( ) ;
if ( sct = = SqlCommandType . UPDATE | | sct = = SqlCommandType . DELETE ) { // 无需处理 Insert 语句
if ( InterceptorIgnoreHelper . willIgnoreTenantLine ( ms . getId ( ) ) ) return ;
// 获得 Mapper 对应的数据权限的规则
List < DataPermissionRule > rules = ruleFactory . getDataPermissionRule ( ms . getId ( ) ) ;
if ( mappedStatementCache . noRewritable ( ms , rules ) ) { // 如果无需重写,则跳过
return ;
}
PluginUtils . MPBoundSql mpBs = mpSh . mPBoundSql ( ) ;
mpBs . sql ( parserMulti ( mpBs . sql ( ) , null ) ) ;
try {
// 初始化上下文
ContextHolder . init ( rules ) ;
// 处理 SQL
mpBs . sql ( parserMulti ( mpBs . sql ( ) , null ) ) ;
} finally {
addMappedStatementCache ( ms ) ;
ContextHolder . clear ( ) ;
}
}
}
@@ -87,10 +117,6 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
@Override
protected void processUpdate ( Update update , int index , String sql , Object obj ) {
final Table table = update . getTable ( ) ;
if ( ignoreTable ( table . getName ( ) ) ) {
// 过滤退出执行
return ;
}
update . setWhere ( this . andExpression ( table , update . getWhere ( ) ) ) ;
}
@@ -99,10 +125,6 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
*/
@Override
protected void processDelete ( Delete delete , int index , String sql , Object obj ) {
if ( ignoreTable ( delete . getTable ( ) . getName ( ) ) ) {
// 过滤退出执行
return ;
}
delete . setWhere ( this . andExpression ( delete . getTable ( ) , delete . getWhere ( ) ) ) ;
}
@@ -378,4 +400,116 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
return new LongValue ( 1L ) ;
}
/**
* 判断 SQL 是否重写。如果没有重写,则添加到 {@link MappedStatementCache} 中
*
* @param ms MappedStatement
*/
private void addMappedStatementCache ( MappedStatement ms ) {
if ( ContextHolder . getRewrite ( ) ) {
return ;
}
// 有重写,进行添加
mappedStatementCache . addNoRewritable ( ms , ContextHolder . getRules ( ) ) ;
}
/**
* SQL 解析上下文,方便透传 {@link DataPermissionRule} 规则
*
* @author 芋道源码
*/
private static final class ContextHolder {
/**
* 该 {@link MappedStatement} 对应的规则
*/
private static final ThreadLocal < List < DataPermissionRule > > RULES = new TransmittableThreadLocal < > ( ) ;
/**
* SQL 是否进行重写
*/
private static final ThreadLocal < Boolean > REWRITE = new TransmittableThreadLocal < > ( ) ;
public static void init ( List < DataPermissionRule > rules ) {
RULES . set ( rules ) ;
REWRITE . set ( false ) ;
}
public static void clear ( ) {
RULES . remove ( ) ;
REWRITE . remove ( ) ;
}
public static boolean getRewrite ( ) {
return REWRITE . get ( ) ;
}
public static void setRewrite ( boolean rewrite ) {
REWRITE . set ( rewrite ) ;
}
public static List < DataPermissionRule > getRules ( ) {
return RULES . get ( ) ;
}
}
/**
* {@link MappedStatement} 缓存
* 目前主要用于,记录 {@link DataPermissionRule} 是否对指定 {@link MappedStatement} 无效
* 如果无效,则可以避免 SQL 的解析,加快速度
*
* @author 芋道源码
*/
private static final class MappedStatementCache {
/**
* 无需重写的映射
*
* value: {@link MappedStatement#getId()} 编号
*/
private final Map < Class < ? extends DataPermissionRule > , Set < String > > noRewritableMappedStatements = new ConcurrentHashMap < > ( ) ;
/**
* 判断是否无需重写
* ps: 虽然有点中文式英语, 但是容易读懂即可
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
* @return 是否无需重写
*/
public boolean noRewritable ( MappedStatement ms , List < DataPermissionRule > rules ) {
// 如果规则为空,说明无需重写
if ( CollUtil . isEmpty ( rules ) ) {
return true ;
}
// 任一规则不在 noRewritableMap 中,则说明可能需要重写
for ( DataPermissionRule rule : rules ) {
Set < String > mappedStatementIds = noRewritableMappedStatements . get ( rule . getClass ( ) ) ;
if ( ! CollUtil . contains ( mappedStatementIds , ms . getId ( ) ) ) { // 不存在,则说明可能要重写
return false ;
}
}
return true ;
}
/**
* 添加无需重写的 MappedStatement
*
* @param ms MappedStatement
* @param rules 数据权限规则数组
*/
public void addNoRewritable ( MappedStatement ms , List < DataPermissionRule > rules ) {
for ( DataPermissionRule rule : rules ) {
Set < String > mappedStatementIds = noRewritableMappedStatements . get ( rule . getClass ( ) ) ;
if ( CollUtil . isNotEmpty ( mappedStatementIds ) ) {
mappedStatementIds . add ( ms . getId ( ) ) ;
} else {
noRewritableMappedStatements . put ( rule . getClass ( ) , SetUtils . asSet ( ms . getId ( ) ) ) ;
}
}
}
}
}