This commit is contained in:
XSWL1018 2024-09-09 21:24:13 +08:00
parent 59208c9c48
commit e60bdb6407
20 changed files with 958 additions and 17 deletions

View File

@ -0,0 +1,20 @@
package com.ruoyi.common.annotation.sql;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import com.ruoyi.common.enums.DataSecurityStrategy;
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataSecurity {
public DataSecurityStrategy strategy() default DataSecurityStrategy.CREEATE_BY;
public String table() default "";
public String joinTableAlise() default "";
}

View File

@ -0,0 +1,14 @@
package com.ruoyi.common.annotation.sql;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface MybatisHandlerOrder {
public int value() default 0;
}

View File

@ -0,0 +1,48 @@
package com.ruoyi.common.context.dataSecurity;
import java.util.List;
import java.util.Map;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.ruoyi.common.enums.SqlType;
import com.ruoyi.common.model.JoinTableModel;
import com.ruoyi.common.model.WhereModel;
public class DataSecurityContextHolder {
private static final ThreadLocal<JSONObject> DATA_SECURITY_SQL_CONTEXT_HOLDER = new ThreadLocal<>();
public static void startDataSecurity() {
JSONObject jsonObject = new JSONObject();
jsonObject.put("isSecurity", Boolean.TRUE);
jsonObject.put(SqlType.WHERE.getSqlType(), new JSONArray());
jsonObject.put(SqlType.JOIN.getSqlType(), new JSONArray());
DATA_SECURITY_SQL_CONTEXT_HOLDER.set(jsonObject);
}
public static void addWhereParam(WhereModel whereModel) {
DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.WHERE.getSqlType()).add(whereModel);
}
public static void clearCache() {
DATA_SECURITY_SQL_CONTEXT_HOLDER.remove();
}
public static boolean isSecurity() {
return DATA_SECURITY_SQL_CONTEXT_HOLDER.get() != null
&& DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getBooleanValue("isSecurity");
}
public static JSONArray getWhere() {
return DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.WHERE.getSqlType());
}
public static void addJoinTable(JoinTableModel joinTableModel) {
DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.JOIN.getSqlType()).add(joinTableModel);
}
public static JSONArray getJoinTables() {
return DATA_SECURITY_SQL_CONTEXT_HOLDER.get().getJSONArray(SqlType.JOIN.getSqlType());
}
}

View File

@ -0,0 +1,44 @@
package com.ruoyi.common.context.page;
import com.alibaba.fastjson2.JSONObject;
import com.ruoyi.common.context.page.model.PageInfo;
public class PageContextHolder {
private static final ThreadLocal<JSONObject> PAGE_CONTEXT_HOLDER = new ThreadLocal<>();
private static final String PAGE_FLAG = "isPage";
private static final String PAGE_INFO = "pageInfo";
private static final String TOTAL = "total";
public static void startPage() {
JSONObject jsonObject = new JSONObject();
jsonObject.put(PAGE_FLAG, Boolean.TRUE);
PAGE_CONTEXT_HOLDER.set(jsonObject);
}
public static void setPageInfo() {
PAGE_CONTEXT_HOLDER.get().put(PAGE_INFO, PageInfo.defaultPageInfo());
}
public static PageInfo getPageInfo() {
return (PageInfo) PAGE_CONTEXT_HOLDER.get().get(PAGE_INFO);
}
public static void clear() {
PAGE_CONTEXT_HOLDER.remove();
}
public static boolean isPage() {
return PAGE_CONTEXT_HOLDER.get() != null && PAGE_CONTEXT_HOLDER.get().getBooleanValue(PAGE_FLAG);
}
public static void setTotal(Long total) {
PAGE_CONTEXT_HOLDER.get().put(TOTAL, total);
}
public static Long getTotal() {
return PAGE_CONTEXT_HOLDER.get().getLong(TOTAL);
}
}

View File

@ -0,0 +1,63 @@
package com.ruoyi.common.context.page.model;
import com.ruoyi.common.core.text.Convert;
import com.ruoyi.common.utils.ServletUtils;
public class PageInfo {
private Long pageNumber;
private Long pageSize;
/**
* 当前记录起始索引
*/
public static final String PAGE_NUM = "pageNum";
/**
* 每页显示记录数
*/
public static final String PAGE_SIZE = "pageSize";
/**
* 排序列
*/
public static final String ORDER_BY_COLUMN = "orderByColumn";
/**
* 排序的方向 "desc" 或者 "asc".
*/
public static final String IS_ASC = "isAsc";
/**
* 分页参数合理化
*/
public static final String REASONABLE = "reasonable";
public Long getPageNumber() {
return pageNumber;
}
public void setPageNumber(Long pageNumber) {
this.pageNumber = pageNumber;
}
public Long getPageSize() {
return pageSize;
}
public void setPageSize(Long pageSize) {
this.pageSize = pageSize;
}
public static PageInfo defaultPageInfo() {
PageInfo pageInfo = new PageInfo();
pageInfo.setPageNumber(Long.valueOf(Convert.toInt(ServletUtils.getParameter(PAGE_NUM), 1)));
pageInfo.setPageSize(Long.valueOf(Convert.toInt(ServletUtils.getParameter(PAGE_SIZE), 10)));
return pageInfo;
}
public Long getOffeset() {
return (pageNumber.longValue() - 1L) * pageSize;
}
}

View File

@ -0,0 +1,25 @@
package com.ruoyi.common.context.page.model;
import java.util.List;
public class RuoyiTableData {
private Long total;
private List<?> data;
public Long getTotal() {
return total;
}
public void setTotal(Long total) {
this.total = total;
}
public List<?> getData() {
return data;
}
public void setData(List<?> data) {
this.data = data;
}
}

View File

@ -0,0 +1,22 @@
package com.ruoyi.common.context.page.model;
import java.util.ArrayList;
import java.util.List;
public class TableInfo<E> extends ArrayList<E> {
private Long total;
public TableInfo(List<? extends E> list) {
super(list);
}
public Long getTotal() {
return total;
}
public void setTotal(Long total) {
this.total = total;
}
}

View File

@ -0,0 +1,8 @@
package com.ruoyi.common.enums;
public enum DataSecurityStrategy {
JOINTABLE_CREATE_BY,
JOINTABLE_USER_ID,
CREEATE_BY,
USER_ID;
}

View File

@ -0,0 +1,18 @@
package com.ruoyi.common.enums;
public enum SqlType {
WHERE("where"),
JOIN("join"),
SELECT("select"),
LIMIT("limit");
private String sqlType;
public String getSqlType() {
return sqlType;
}
private SqlType(String sqlType) {
this.sqlType = sqlType;
}
}

View File

@ -0,0 +1,7 @@
package com.ruoyi.common.handler.sql;
public interface MybatisAfterHandler {
Object handleObject(Object object) throws Throwable;
}

View File

@ -0,0 +1,15 @@
package com.ruoyi.common.handler.sql;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
public interface MybatisPreHandler {
void preHandle(Executor executor, MappedStatement mappedStatement, Object params,
RowBounds rowBounds, ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql)
throws Throwable;
}

View File

@ -0,0 +1,100 @@
package com.ruoyi.common.handler.sql.dataSecurity;
import java.lang.reflect.Field;
import java.util.List;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;
import org.springframework.util.ReflectionUtils;
import com.ruoyi.common.annotation.sql.MybatisHandlerOrder;
import com.ruoyi.common.context.dataSecurity.DataSecurityContextHolder;
import com.ruoyi.common.handler.sql.MybatisPreHandler;
import com.ruoyi.common.model.JoinTableModel;
import com.ruoyi.common.model.WhereModel;
import com.ruoyi.common.utils.StringUtils;
import com.ruoyi.common.utils.sql.SqlUtil;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Join;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
@MybatisHandlerOrder(1)
@Component
public class DataSecurityPreHandler implements MybatisPreHandler {
private static final Field sqlFiled = ReflectionUtils.findField(BoundSql.class, "sql");
static {
sqlFiled.setAccessible(true);
}
@Override
public void preHandle(Executor executor, MappedStatement mappedStatement, Object params, RowBounds rowBounds,
ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws Throwable {
if (DataSecurityContextHolder.isSecurity()) {
Statement sql = parseSql(SqlUtil.parseSql(boundSql.getSql()));
sqlFiled.set(boundSql, sql.toString());
}
}
private static Statement parseSql(Statement statement) throws JSQLParserException {
if (statement instanceof Select) {
Select select = (Select) statement;
// plain.setWhere(CCJSqlParserUtil.parseCondExpression(handleWhere(expWhere)));
handleWhere(select);
handleJoin(select);
return select;
} else {
return statement;
}
}
private static void handleWhere(Select select) throws JSQLParserException {
PlainSelect plain = select.getPlainSelect();
Expression expWhere = plain.getWhere();
StringBuilder whereParam = new StringBuilder(" ");
String where = expWhere != null ? expWhere.toString() : null;
if (DataSecurityContextHolder.getWhere() == null || DataSecurityContextHolder.getWhere().size() <= 0) {
return;
}
DataSecurityContextHolder.getWhere().forEach(item -> {
whereParam.append(((WhereModel) item).getSqlString());
});
where = StringUtils.isEmpty(where) ? whereParam.toString().substring(5, whereParam.length())
: where + " " + whereParam.toString();
plain.setWhere(CCJSqlParserUtil.parseCondExpression(where));
}
private static void handleJoin(Select select) {
PlainSelect selectBody = select.getPlainSelect();
if (DataSecurityContextHolder.getJoinTables() == null || DataSecurityContextHolder.getJoinTables().size() <= 0) {
return;
}
DataSecurityContextHolder.getJoinTables().forEach(item -> {
JoinTableModel tableModel = (JoinTableModel) item;
Table table = new Table(tableModel.getJoinTable());
table.setAlias(new Alias(tableModel.getJoinTableAlise()));
Join join = new Join();
join.setRightItem(table);
join.setInner(true);
Expression onExpression = new EqualsTo(new Column(tableModel.getFromTableColumnString()),
new Column(tableModel.getJoinTableColumnString()));
join.setOnExpressions(List.of(onExpression));
selectBody.addJoins(join);
});
}
}

View File

@ -0,0 +1,31 @@
package com.ruoyi.common.handler.sql.page;
import java.util.ArrayList;
import java.util.List;
import org.springframework.stereotype.Component;
import com.ruoyi.common.annotation.sql.MybatisHandlerOrder;
import com.ruoyi.common.context.page.PageContextHolder;
import com.ruoyi.common.context.page.model.TableInfo;
import com.ruoyi.common.handler.sql.MybatisAfterHandler;
@MybatisHandlerOrder(1)
@Component
public class PageAfterHandler implements MybatisAfterHandler {
@Override
public Object handleObject(Object object) throws Throwable {
if (PageContextHolder.isPage()) {
if (object instanceof List) {
TableInfo tableInfo = new TableInfo<>((List) object);
tableInfo.setTotal(PageContextHolder.getTotal());
PageContextHolder.clear();
return tableInfo;
}
return object;
}
return object;
}
}

View File

@ -0,0 +1,142 @@
package com.ruoyi.common.handler.sql.page;
import java.lang.reflect.Field;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ResultMap;
import org.apache.ibatis.mapping.ResultMapping;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;
import org.springframework.util.ReflectionUtils;
import com.ruoyi.common.annotation.sql.MybatisHandlerOrder;
import com.ruoyi.common.context.page.PageContextHolder;
import com.ruoyi.common.context.page.model.PageInfo;
import com.ruoyi.common.handler.sql.MybatisPreHandler;
import com.ruoyi.common.utils.sql.SqlUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.Limit;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectItem;
@Component
@MybatisHandlerOrder(2)
public class PagePreHandler implements MybatisPreHandler {
private static final List<ResultMapping> EMPTY_RESULTMAPPING = new ArrayList<ResultMapping>(0);
private static final String SELECT_COUNT_SUFIX = "_SELECT_COUNT";
private static final Field sqlFiled = ReflectionUtils.findField(BoundSql.class, "sql");
static {
sqlFiled.setAccessible(true);
}
@Override
public void preHandle(Executor executor, MappedStatement mappedStatement, Object params, RowBounds rowBounds,
ResultHandler resultHandler, CacheKey cacheKey, BoundSql boundSql) throws Throwable {
if (PageContextHolder.isPage()) {
String originSql = boundSql.getSql();
Statement sql = SqlUtil.parseSql(originSql);
if (sql instanceof Select) {
PageInfo pageInfo = PageContextHolder.getPageInfo();
Statement handleLimit = handleLimit((Select) sql, pageInfo);
Statement countSql = getCountSql((Select) sql);
Long count = getCount(executor, mappedStatement, params, boundSql, rowBounds, resultHandler,
countSql.toString());
PageContextHolder.setTotal(count);
sqlFiled.set(boundSql, handleLimit.toString());
cacheKey = executor.createCacheKey(mappedStatement, params, rowBounds, boundSql);
}
}
}
private static MappedStatement createCountMappedStatement(MappedStatement ms, String newMsId) {
MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), newMsId,
ms.getSqlSource(),
ms.getSqlCommandType());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
StringBuilder keyProperties = new StringBuilder();
for (String keyProperty : ms.getKeyProperties()) {
keyProperties.append(keyProperty).append(",");
}
keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
builder.keyProperty(keyProperties.toString());
}
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
// count查询返回值int
List<ResultMap> resultMaps = new ArrayList<ResultMap>();
ResultMap resultMap = new ResultMap.Builder(ms.getConfiguration(), ms.getId(), Long.class,
EMPTY_RESULTMAPPING)
.build();
resultMaps.add(resultMap);
builder.resultMaps(resultMaps);
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
public static Long getCount(Executor executor, MappedStatement mappedStatement, Object parameter,
BoundSql boundSql, RowBounds rowBounds, ResultHandler resultHandler, String countSql)
throws SQLException {
Map<String, Object> additionalParameters = boundSql.getAdditionalParameters();
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,
boundSql.getParameterMappings(), parameter);
for (String key : additionalParameters.keySet()) {
countBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
}
CacheKey countKey = executor.createCacheKey(mappedStatement, parameter, RowBounds.DEFAULT, countBoundSql);
List<Object> query = executor.query(
createCountMappedStatement(mappedStatement, getCountMSId(mappedStatement)),
parameter, RowBounds.DEFAULT, resultHandler, countKey,
countBoundSql);
return (Long) query.get(0);
}
private static String getCountMSId(MappedStatement mappedStatement) {
return mappedStatement.getId() + SELECT_COUNT_SUFIX;
}
public static Statement getCountSql(Select select) {
PlainSelect plain = select.getPlainSelect();
PlainSelect countPlain = new PlainSelect();
countPlain.setSelectItems(List.of(new SelectItem<>(new Column("COUNT(0)"))));
countPlain.setJoins(plain.getJoins());
countPlain.setWhere(plain.getWhere());
countPlain.setFromItem(plain.getFromItem());
countPlain.setDistinct(plain.getDistinct());
countPlain.setHaving(plain.getHaving());
countPlain.setIntoTables(plain.getIntoTables());
// countPlain.setOrderByElements(plain.getOrderByElements());
return plain;
}
private static Statement handleLimit(Select select, PageInfo pageInfo) {
Limit limit = new Limit();
limit.setRowCount(new Column(pageInfo.getPageSize().toString()));
limit.setOffset(new Column(pageInfo.getOffeset().toString()));
PlainSelect plain = select.getPlainSelect();
plain.setLimit(limit);
return select;
}
}

View File

@ -0,0 +1,85 @@
package com.ruoyi.common.model;
import com.ruoyi.common.utils.StringUtils;
public class JoinTableModel {
private String joinTable;
private String joinTableAlise;
private String fromTable;
private String fromTableAlise;
private String joinTableColumn;
private String fromTableColumn;
public String getJoinTable() {
return joinTable;
}
public void setJoinTable(String joinTable) {
this.joinTable = joinTable;
}
public String getJoinTableAlise() {
if (StringUtils.isEmpty(this.joinTableAlise)) {
return this.joinTable;
}
return joinTableAlise;
}
public void setJoinTableAlise(String joinTableAlise) {
this.joinTableAlise = joinTableAlise;
}
public String getFromTable() {
return fromTable;
}
public void setFromTable(String fromTable) {
this.fromTable = fromTable;
}
public String getFromTableAlise() {
if (StringUtils.isEmpty(this.fromTableAlise)) {
return this.fromTable;
}
return fromTableAlise;
}
public void setFromTableAlise(String fromTableAlise) {
this.fromTableAlise = fromTableAlise;
}
public String getJoinTableColumn() {
return joinTableColumn;
}
public void setJoinTableColumn(String joinTableColumn) {
this.joinTableColumn = joinTableColumn;
}
public String getFromTableColumn() {
return fromTableColumn;
}
public void setFromTableColumn(String fromTableColumn) {
this.fromTableColumn = fromTableColumn;
}
public String getJoinTableColumnString() {
return this.getJoinTableAlise() + "." + this.joinTableColumn;
}
public String getFromTableColumnString() {
if (StringUtils.isEmpty(this.getFromTableAlise())) {
return this.fromTableColumn;
}
return this.getFromTableAlise() + "." + this.fromTableColumn;
}
}

View File

@ -0,0 +1,67 @@
package com.ruoyi.common.model;
import com.ruoyi.common.utils.StringUtils;
public class WhereModel {
private String whereColumn;
private String table;
private Object value;
private String connectType;
private String method;
public static final String METHOD_EQUAS = "=";
public static final String METHOD_LIKE = "like";
public static final String CONNECT_AND = "AND";
public static final String CONNECT_OR = "OR";
public String getWhereColumn() {
return whereColumn;
}
public void setWhereColumn(String whereColumn) {
this.whereColumn = whereColumn;
}
public String getTable() {
return table;
}
public void setTable(String table) {
this.table = table;
}
public Object getValue() {
return value;
}
public void setValue(Object value) {
this.value = value;
}
public String getFullTableColumn() {
if (StringUtils.isEmpty(this.table)) {
return this.whereColumn;
}
return this.table + "." + this.whereColumn;
}
public String getConnectType() {
return connectType;
}
public void setConnectType(String connectType) {
this.connectType = connectType;
}
public String getMethod() {
return method;
}
public void setMethod(String method) {
this.method = method;
}
public String getSqlString() {
return String.format(" %s %s %s %s ", this.getConnectType(), this.getFullTableColumn(), this.method, this.value);
}
}

View File

@ -0,0 +1,14 @@
package com.ruoyi.common.utils;
import com.ruoyi.common.context.dataSecurity.DataSecurityContextHolder;
public class DataSecurityUtil {
public static void closeDataSecurity() {
DataSecurityContextHolder.clearCache();
}
public static void startDataSecurity() {
DataSecurityContextHolder.startDataSecurity();
}
}

View File

@ -1,32 +1,37 @@
package com.ruoyi.common.utils.sql;
import java.io.StringReader;
import com.ruoyi.common.exception.UtilException;
import com.ruoyi.common.utils.StringUtils;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.statement.Statement;
/**
* sql操作工具类
*
* @author ruoyi
*/
public class SqlUtil
{
public class SqlUtil {
/**
* 定义常用的 sql关键字
*/
public static String SQL_REGEX = "and |extractvalue|updatexml|sleep|exec |insert |select |delete |update |drop |count |chr |mid |master |truncate |char |declare |or |union |like |+|/*|user()";
public static String SQL_REGEX = "and |extractvalue|updatexml|exec |insert |select |delete |update |drop |count |chr |mid |master |truncate |char |declare |or |+|user()";
/**
* 仅支持字母数字下划线空格逗号小数点支持多个字段排序
*/
public static String SQL_PATTERN = "[a-zA-Z0-9_\\ \\,\\.]+";
private static final CCJSqlParserManager parserManager = new CCJSqlParserManager();
/**
* 检查字符防止注入绕过
*/
public static String escapeOrderBySql(String value)
{
if (StringUtils.isNotEmpty(value) && !isValidOrderBySql(value))
{
public static String escapeOrderBySql(String value) {
if (StringUtils.isNotEmpty(value) && !isValidOrderBySql(value)) {
throw new UtilException("参数不符合规范,不能进行查询");
}
return value;
@ -35,27 +40,26 @@ public class SqlUtil
/**
* 验证 order by 语法是否符合规范
*/
public static boolean isValidOrderBySql(String value)
{
public static boolean isValidOrderBySql(String value) {
return value.matches(SQL_PATTERN);
}
/**
* SQL关键字检查
*/
public static void filterKeyword(String value)
{
if (StringUtils.isEmpty(value))
{
public static void filterKeyword(String value) {
if (StringUtils.isEmpty(value)) {
return;
}
String[] sqlKeywords = StringUtils.split(SQL_REGEX, "\\|");
for (String sqlKeyword : sqlKeywords)
{
if (StringUtils.indexOfIgnoreCase(value, sqlKeyword) > -1)
{
for (String sqlKeyword : sqlKeywords) {
if (StringUtils.indexOfIgnoreCase(value, sqlKeyword) > -1) {
throw new UtilException("参数存在SQL注入风险");
}
}
}
public static Statement parseSql(String sql) throws JSQLParserException {
return parserManager.parse(new StringReader(sql));
}
}

View File

@ -0,0 +1,85 @@
package com.ruoyi.framework.aspectj;
import java.util.List;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.stereotype.Component;
import com.ruoyi.common.annotation.sql.DataSecurity;
import com.ruoyi.common.context.dataSecurity.DataSecurityContextHolder;
import com.ruoyi.common.enums.DataSecurityStrategy;
import com.ruoyi.common.model.JoinTableModel;
import com.ruoyi.common.model.WhereModel;
import com.ruoyi.common.utils.SecurityUtils;
import com.ruoyi.common.utils.StringUtils;
import ch.qos.logback.core.util.StringUtil;
@Aspect
@Component
public class DataSecurityAspect {
@Before(value = "@annotation(dataSecurity)")
public void doBefore(final JoinPoint point, DataSecurity dataSecurity) throws Throwable {
DataSecurityContextHolder.startDataSecurity();
switch (dataSecurity.strategy()) {
case CREEATE_BY:
WhereModel createByModel = new WhereModel();
createByModel.setTable(dataSecurity.table());
createByModel.setValue("\"" + SecurityUtils.getUsername() + "\"");
createByModel.setWhereColumn("create_by");
createByModel.setMethod(WhereModel.METHOD_EQUAS);
createByModel.setConnectType(WhereModel.CONNECT_AND);
DataSecurityContextHolder.addWhereParam(createByModel);
break;
case USER_ID:
WhereModel userIdModel = new WhereModel();
userIdModel.setTable(dataSecurity.table());
userIdModel.setTable("user_id");
userIdModel.setValue(SecurityUtils.getUserId());
userIdModel.setConnectType(WhereModel.CONNECT_AND);
userIdModel.setMethod(WhereModel.METHOD_EQUAS);
DataSecurityContextHolder.addWhereParam(userIdModel);
break;
case JOINTABLE_CREATE_BY:
JoinTableModel createByTableModel = new JoinTableModel();
createByTableModel.setFromTable(dataSecurity.table());
createByTableModel.setFromTableAlise(dataSecurity.table());
createByTableModel.setJoinTable("sys_user");
if (!StringUtils.isEmpty(dataSecurity.joinTableAlise())) {
createByTableModel.setJoinTableAlise(dataSecurity.joinTableAlise());
}
createByTableModel.setFromTableColumn("create_by");
createByTableModel.setJoinTableColumn("user_name");
DataSecurityContextHolder.addJoinTable(createByTableModel);
break;
case JOINTABLE_USER_ID:
JoinTableModel userIdTableModel = new JoinTableModel();
userIdTableModel.setFromTable(dataSecurity.table());
userIdTableModel.setFromTableAlise(dataSecurity.table());
userIdTableModel.setJoinTable("sys_user");
if (!StringUtils.isEmpty(dataSecurity.joinTableAlise())) {
userIdTableModel.setJoinTableAlise(dataSecurity.joinTableAlise());
}
userIdTableModel.setFromTableColumn("user_id");
userIdTableModel.setJoinTableColumn("user_id");
DataSecurityContextHolder.addJoinTable(userIdTableModel);
break;
default:
break;
}
}
@After(value = " @annotation(dataSecurity)")
public void doAfter(final JoinPoint point, DataSecurity dataSecurity) {
DataSecurityContextHolder.clearCache();
}
}

View File

@ -0,0 +1,129 @@
package com.ruoyi.framework.interceptor.mybatis;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import com.ruoyi.common.annotation.sql.MybatisHandlerOrder;
import com.ruoyi.common.handler.sql.MybatisAfterHandler;
import com.ruoyi.common.handler.sql.MybatisPreHandler;
import jakarta.annotation.PostConstruct;
@Component
@Intercepts({
@Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class,
RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class }),
@Signature(type = Executor.class, method = "query", args = {
MappedStatement.class, Object.class, RowBounds.class,
ResultHandler.class })
})
public class MybatisInterceptor implements Interceptor {
@Autowired
private List<MybatisPreHandler> preHandlerBeans;
@Autowired
private List<MybatisAfterHandler> afterHandlerBeans;
private static List<MybatisPreHandler> preHandlersChain;
private static List<MybatisAfterHandler> afterHandlersChain;
@PostConstruct
public void init() {
List<MybatisPreHandler> sortedPreHandlers = preHandlerBeans.stream().sorted((item1, item2) -> {
int a;
int b;
MybatisHandlerOrder ann1 = item1.getClass().getAnnotation(MybatisHandlerOrder.class);
MybatisHandlerOrder ann2 = item2.getClass().getAnnotation(MybatisHandlerOrder.class);
if (ann1 == null) {
a = 0;
} else {
a = ann1.value();
}
if (ann2 == null) {
b = 0;
} else {
b = ann2.value();
}
return a - b;
}).collect(Collectors.toList());
preHandlersChain = sortedPreHandlers;
List<MybatisAfterHandler> sortedAfterHandlers = afterHandlerBeans.stream().sorted((item1, item2) -> {
int a;
int b;
MybatisHandlerOrder ann1 = item1.getClass().getAnnotation(MybatisHandlerOrder.class);
MybatisHandlerOrder ann2 = item2.getClass().getAnnotation(MybatisHandlerOrder.class);
if (ann1 == null) {
a = 0;
} else {
a = ann1.value();
}
if (ann2 == null) {
b = 0;
} else {
b = ann2.value();
}
return a - b;
}).collect(Collectors.toList());
afterHandlersChain = sortedAfterHandlers;
}
@Override
public Object intercept(Invocation invocation) throws Throwable {
Executor targetExecutor = (Executor) invocation.getTarget();
Object[] args = invocation.getArgs();
if (args.length < 6) {
if (preHandlersChain != null && preHandlersChain.size() > 0) {
MappedStatement ms = (MappedStatement) args[0];
Object parameterObject = args[1];
RowBounds rowBounds = (RowBounds) args[2];
Executor executor = (Executor) invocation.getTarget();
BoundSql boundSql = ms.getBoundSql(parameterObject);
// 可以对参数做各种处理
CacheKey cacheKey = executor.createCacheKey(ms, parameterObject, rowBounds, boundSql);
for (MybatisPreHandler item : preHandlersChain) {
item.preHandle(targetExecutor, ms, args[1], (RowBounds) args[2],
(ResultHandler) args[3], cacheKey, boundSql);
}
}
Object result = invocation.proceed();
if (afterHandlersChain != null && afterHandlersChain.size() > 0) {
for (MybatisAfterHandler item : afterHandlersChain) {
item.handleObject(result);
}
}
return result;
}
if (preHandlersChain != null && preHandlersChain.size() > 0) {
for (MybatisPreHandler item : preHandlersChain) {
item.preHandle(targetExecutor, (MappedStatement) args[0], args[1], (RowBounds) args[2],
(ResultHandler) args[3], (CacheKey) args[4], (BoundSql) args[5]);
}
}
Object result = invocation.proceed();
if (afterHandlersChain != null && afterHandlersChain.size() > 0) {
for (MybatisAfterHandler item : afterHandlersChain) {
result = item.handleObject(result);
}
}
return result;
}
}