基於 SQL 解析的 JPA 多租戶方案

Createsequence發表於2024-03-25

概述

最近在對一個使用 JPA 的老專案進行多租戶改造,由於年代過於久遠,陳年屎山讓人實在不敢輕舉妄動,最後只能選擇一個改造成本最小的方案,那就是透過攔截器改 SQL,動態新增租戶 ID 作為查詢條件。
本篇文章用於記錄筆者基於該方案解決此問題的踩坑和思考過程,部分程式碼與實際程式碼有所出入。如果希望直接獲取可執行的程式碼,可以直接在 github 倉庫獲取。

1.SQL 攔截器

由於 JPA 底層是基於 Hibernate 實現的,而 Hibernate 本身提供了 StatementInspector 介面用於實現 SQL 攔截。因此我們只需要在這個階段對 SQL 進行解析,然後為需要按租戶進行隔離的資源表動態的新增租戶 ID 的過濾條件即可。
這裡我們選擇使用 JSqlParser 作為我們的 SQL 解析器。它社群還算活躍,文件詳細,最重要的是,API 比較簡單易懂。下文若不特意澄清,則所有與 SQL 解析相關的類都來自於它。

1.1.簡單實現

在最開始,我們寫一個簡單的實現來驗證一下可行性。
假設,我們需要指定攔截針對表 t_resource 的查詢語句,為其新增 tenant_id = xxx 作為查詢條件,那麼這個 SQL 攔截器需要做到:

  1. 將 SQL 解析為 Statement 物件,然後檢查其是否為查詢類 SQL;
  2. 獲取 SQL 的 from 語句,並判斷查詢的表是否為我們要攔截的表;
  3. 解析 where 語句:
    • 若原本沒有任何條件,則為其生成一個 where t.tenant_id = xxx 的條件;
    • 如果原本已經有條件了,則為其在最後拼接 and t.tenant_id = xxx 的條件;

這裡我們針對這個需求給出一個簡單的實現:

@Slf4j
public class TenantSQLInterceptor {

    /**
     * 處理SQL語句
     *
     * @param sql SQL語句
     * @param table 要攔截器的租戶表名
     * @param column 租戶欄位名
     * @param value 租戶欄位值
     * @return 處理後的SQL語句
     */
    public String handle(String sql, String table, String column, String value) {
        log.debug("租戶攔截器攔截原始 SQL: {}", sql);
        String handledSql = doHandle(sql, table, column, value);
        log.info("租戶攔截器攔截後 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    /**
     * 處理SQL語句
     *
     * @param sql SQL語句
     * @return 處理後的SQL語句
     */
    @Nullable
    public String doHandle(String sql, String table, String column, String value) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(statement -> doHandle(statement, table, column, value))
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
            return statements;
        } catch (JSQLParserException e) {
            log.error("SQL 解析失敗: {}", sql, e);
            throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失敗");
        }
    }

    private Statement doHandle(Statement statement, String table, String column, String value) {
        if (!(statement instanceof Select)) {
            return statement;
        }
        try {
            SelectBody selectBody = ((Select) statement).getSelectBody();
            // 目前只處理普通的 SQL 查詢
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                FromItem fromItem = plainSelect.getFromItem();
                Expression where = plainSelect.getWhere();

                // 如果查詢的表即為要攔截的租戶表,則為查詢條件新增租戶條件
                if (fromItem instanceof Table) {
                    String queryTable = ((Table) fromItem).getName();
                    if (Objects.equals(queryTable, table)) {
                        where = appendTenantCondition(plainSelect.getWhere(), fromItem, value, column);
                        plainSelect.setWhere(where);
                    }
                }
            }
        } catch (Exception ex) {
            log.error("SQL 處理失敗: {}", statement, ex);
            throw new RuntimeException("SQL 處理失敗", ex);
        }
        return statement;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        // 生成一個 tenant_id = xxx 的條件
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在別名,則欄位應該變“表別名.欄位名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }
}

測試一下:

public static void main(String[] args) {
    String sql = "select * from t_resource r where r.order = 1";
    TenantSQLInterceptor tenantSQLInterceptor = new TenantSQLInterceptor();
    String handledSql = tenantSQLInterceptor.handle(sql, "t_resource", "tenant_id", "1");
    System.out.println(handledSql); // = SELECT * FROM resource r WHERE r.t_resource = 1 AND r.tenant_id = '1'
}

雖然還非常簡陋,不過這個攔截器已經能夠初步實現我們想要的功能了,不過要投入實際場景,顯然還需要做出“一點點”改進。

1.2.從上下文獲取租戶資訊

首先,真實的使用場景中,一個 SQL 可能會同時涉及到多張需要攔截器的表,並且每張表對應的租戶 ID 仍然有可能不同,因此我們最好直接將相關的配置資訊提取出來,改為透過一個上下文物件進行獲取:

@Slf4j
public class TenantSQLInterceptor {

    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 設定租戶資訊
     *
     * @param tenantInfo 租戶資訊
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租戶資訊
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }


    public String handle(String sql) {
        // 如果未設定租戶資訊,則直接返回原始SQL
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        if (Objects.isNull(tenantInfo)) {
            return sql;
        }
        log.debug("租戶攔截器攔截原始 SQL: {}", sql);
        String handledSql = doHandle(sql);
        log.info("租戶攔截器攔截後 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    /**
     * 處理SQL語句
     *
     * @param sql SQL語句
     * @return 處理後的SQL語句
     */
    @Nullable
    public String doHandle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
            return statements;
        } catch (JSQLParserException e) {
            log.error("SQL 解析失敗: {}", sql, e);
            throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失敗");
        }
    }

    private Statement doHandle(Statement statement) {
        if (!(statement instanceof Select)) {
            return statement;
        }
        try {
            SelectBody selectBody = ((Select) statement).getSelectBody();
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                FromItem fromItem = plainSelect.getFromItem();
                Expression where = plainSelect.getWhere();

                // 如果查詢的表即為要攔截的租戶表,則為查詢條件新增租戶條件
                if (fromItem instanceof Table) {
                    String queryTable = ((Table) fromItem).getName();
                    TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
                    String tenantColumn = tenantInfo.tablesWithTenantColumn.get(queryTable);
                    if (Objects.nonNull(tenantColumn)) {
                        plainSelect.setWhere(appendTenantCondition(where, fromItem, tenantInfo.tenantId, tenantColumn));
                    }
                }
            }
        } catch (Exception ex) {
            log.error("SQL 處理失敗: {}", statement, ex);
            throw new RuntimeException("SQL 處理失敗", ex);
        }
        return statement;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在別名,則欄位應該變“表別名.欄位名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租戶資訊
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租戶ID
         */
        private final String tenantId;
        /**
         * 要新增租戶條件的表名稱與對應的租戶欄位
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

1.3.複雜 SQL 的解析

在實際場景中,尤其是涉及到手寫 SQL 的場景中,SQL 往往比較複雜,比如:

  • 查詢可能基於一張虛擬表,比如: select * from (selecrt from t1 where t1.id = xx) t2 這種情況。
  • 可能會存在關聯查,比如: select * from t1 left join t2 on t2.id = t1.tid 這種情況。
  • 可能會涉及到子查詢,比如:select * from t where t.id in (select t2.tid from t2 where t2.id = xxx) 這種情況。

除上述這幾種情況外,我們還需要考慮各種組合的場景,比如 union 型別的聯合查詢,函式與子查詢的巢狀,基於虛擬表的聯查……等等。

1.3.1.改進方案

雖然情況有很多種,不過值得高興的是,我們還是有辦法為其歸納出一個處理流程。簡單的來說,就是檢查所有可能存在巢狀查詢的語句,進行遞迴解析:

  1. 第一步,先解析語句本身,如果是 union 這種聯合查詢,則將其拆分為多條單體 SQL 進行遞迴解析;
  2. 第二步,對於單條 SQL,解析其 select 的欄位,如果存在函式或者子查詢,則將每個欄位其作為一個單體 SQL 進行遞迴解析;
  3. 第三步,解析其 from 語句,如果存在函式或者基於子查詢的臨時表,則將子查詢作為一個單體 SQL 進行遞迴解析;
  4. 第四步,解析 join 語句:
    1. 如果 join 的表本身是基於子查詢的臨時表,則將子查詢作為一個單體 SQL 進行遞迴解析;
    2. 如果 on 條件中存在函式或者子查詢,則將其作為單體 SQL 進行遞迴解析;
  5. 第五步,解析 where 條件,如果存在函式或者基於子查詢的條件欄位,則將其作為一個單體 SQL 進行遞迴解析。

基於上述分析,我們需要對現有的程式碼做出一點調整:

  • 在 doHandle 方法中,我們需要判斷 fromItem 的型別,如果是子查詢,則需要進行遞迴處理。
  • 在 doHandle 方法後,我們需要新增一部分對 join 語句的處理,由於 join 語句同樣由 from 和 where 兩部分組成,因此此處的邏輯應當與正常的 select 差不多。
  • 在 appendTenantCondition 方法之前,我們需要增加對特殊條件的處理,對應每個條件,我們都需要檢查是否存在可能的子查詢,如果存則需要進行遞迴處理。

1.3.2.改進後的程式碼

根據改進方案,我們再次調整程式碼:

@Slf4j
public class TenantSQLInterceptor {
    
    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 設定租戶資訊
     *
     * @param tenantInfo 租戶資訊
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租戶資訊
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }
    
    /**
     * 處理SQL語句
     *
     * @param sql SQL語句
     * @return 處理後的SQL語句
     */
    @Nullable
    public String handle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
        } catch (JSQLParserException e) {
            log.error("SQL 解析失敗: {}", sql, e);
            throw new RuntimeException("SQL 解析失敗", e);
        }
        return statements;
    }

    private Statement doHandle(Statement statement) {
        try {
            if (statement instanceof Select) {
                processSelect(((Select) statement).getSelectBody());
            } else if (statement instanceof Update) {
                processUpdate((Update) statement);
            } else if (statement instanceof Delete) {
                processDelete((Delete) statement);
            } else if (statement instanceof Insert) {
                processInsert((Insert) statement);
            }
        } catch (Exception ex) {
            log.error("SQL 處理失敗: {}", statement, ex);
            throw new RuntimeException("SQL 處理失敗", ex);
        }
        return statement;
    }

    private void processSelect(SelectBody selectBody) {
        // 普通查詢
        if (selectBody instanceof PlainSelect) {
            processSelect((PlainSelect) selectBody);
        }
        // 巢狀查詢,比如 select xx from (select yy from t)
        else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelect(withItem.getSelectBody());
            }
        }
        // 聯合查詢,比如 union
        else if (selectBody instanceof SetOperationList) {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (CollUtil.isNotEmpty(operationList.getSelects())) {
                operationList.getSelects().forEach(this::processSelect);
            }
        }
        // 值查詢,比如 select 1, 2, 3
        else if (selectBody instanceof ValuesStatement) {
            List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
            if (CollUtil.isNotEmpty(expressions)) {
                expressions.forEach(exp -> processCondition(exp, null));
            }
        } else {
            log.error("無法解析的 select 語句:{}({})", selectBody, selectBody.getClass());
            throw new RuntimeException("不支援的查詢語句:" + selectBody.getClass().getName()
        }
    }

    /**
     * 處理插入語句
     *
     * @param insert 插入語句
     */
    protected void processInsert(Insert insert) {
        // do nothing
    }

    /**
     * 處理刪除語句
     *
     * @param delete 刪除語句
     */
    protected void processDelete(Delete delete) {
        Table table = delete.getTable();
        delete.setWhere(processCondition(delete.getWhere(), table));
        // 如果還存在關聯查詢
        List<Join> joins = delete.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 處理更新語句
     *
     * @param update 更新語句
     */
    protected void processUpdate(Update update) {
        Table table = update.getTable();
        update.setWhere(processCondition(update.getWhere(), table));
        // 如果還存在關聯查詢
        List<Join> joins = update.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 處理查詢語句
     *
     * @param plainSelect 查詢語句
     */
    protected void processSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        // 如果是普通的表名
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
        }
        // 如果是子查詢,比如 select * from (select xxx from yyy)
        else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelect(subSelect.getSelectBody());
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
        }
        // 如果是帶有特殊函式的子查詢,比如 lateral (select sum(*) from yyy)
        else if (fromItem instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
        }
        // 未知型別的查詢,直接報錯
        else {
            log.error("無法解析的 from 語句:{}({})", fromItem, fromItem.getClass());
            throw new RuntimeException("不支援的查詢語句:" + fromItem.getClass().getName()
        }

        // 如果還存在關聯查詢
        List<Join> joins = plainSelect.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 處理關聯查詢
     *
     * @param join 關聯查詢
     */
    protected void processJoin(Join join) {
        FromItem joinTable = join.getRightItem();
        if (joinTable instanceof Table) {
            Table table = (Table) joinTable;
            join.setOnExpression(processCondition(join.getOnExpression(), table));
        }
        else if (joinTable instanceof SubSelect) {
            processSelect(((SubSelect) joinTable).getSelectBody());
        }
        else if (joinTable instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
        }
        else {
            log.error("無法解析的 join 語句:{}({})", joinTable, joinTable.getClass());
            throw new RuntimeException("不支援的查詢語句:" + joinTable.getClass().getName());
        }
    }

    /**
     * <p>獲取新增了租戶條件的查詢條件,若條件中存在子查詢,則也會為子查詢新增租戶條件。
     *
     * @param expression 條件表示式
     * @param table 表
     * @return 新增租戶條件後的條件表示式
     */
    protected Expression processCondition(@Nullable Expression expression, FromItem table) {
        // 如果已經不可拆分的表示式,則直接返回
        if (isBasicExpression(expression)) {
            return expression;
        }
        // 如果是子查詢,則需要對子查詢進行遞迴處理
        else if (expression instanceof SubSelect) {
            processSelect(((SubSelect) expression).getSelectBody());
        }
        // 如果是 in 條件,比如:xxx in (select xx from yy……),則需要對子查詢進行遞迴處理
        else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelect(((SubSelect) rightItems).getSelectBody());
            }
        }
        // 如果是 not 或者 != 條件,則需要對裡面的條件進行遞迴處理
        else if (expression instanceof NotExpression) {
            NotExpression notExpression = (NotExpression) expression;
            processCondition(notExpression.getExpression(), table);
        }
        // 如果是 (xxx != xxx),則需要對括號裡面的表示式進行遞迴處理
        else if (expression instanceof Parenthesis) {
            Parenthesis parenthesis = (Parenthesis) expression;
            Expression content = parenthesis.getExpression();
            processCondition(content, table);
        }
        // 如果是二元表示式,比如:xx = xx,xx > xx,則需要對左右兩邊的表示式進行遞迴處理
        else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) expression;
            Expression left = binaryExpression.getLeftExpression();
            processCondition(left, table);
            Expression right = binaryExpression.getRightExpression();
            processCondition(right, table);
        }
        // 如果是函式,比如:if(xx, xx) ,則需要對函式的引數進行遞迴處理
        else if (expression instanceof Function) {
            Function function = (Function) expression;
            ExpressionList parameters = function.getParameters();
            if (parameters != null) {
                parameters.getExpressions().forEach(param -> processCondition(param, table));
            }
        }
        // 如果是 case when 語句,則需要對 when 和 then 兩個條件進行遞迴處理
        else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            processCondition(whenClause.getWhenExpression(), table);
            processCondition(whenClause.getThenExpression(), table);
        }
        // 如果是 case 語句,則需要對 switch、when、then、else 四個條件進行遞迴處理
        else if (expression instanceof CaseExpression) {
            CaseExpression caseExpression = (CaseExpression) expression;
            processCondition(caseExpression.getSwitchExpression(), table);
            List<WhenClause> whenClauses = caseExpression.getWhenClauses();
            if (CollUtil.isNotEmpty(whenClauses)) {
                whenClauses.forEach(whenClause -> {
                    processCondition(whenClause.getWhenExpression(), table);
                    processCondition(whenClause.getThenExpression(), table);
                });
            }
            processCondition(caseExpression.getElseExpression(), table);
        }
        // 如果是 exists 語句,比如:exists (select xx from yy……),則需要對子查詢進行遞迴處理
        else if (expression instanceof ExistsExpression) {
            Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
            if (existsExpression instanceof SubSelect) {
                processSelect(((SubSelect) existsExpression).getSelectBody());
            }
        }
        // 如果是 all 或者 any 語句,比如:xx > all (select xx from yy……),則需要對子查詢進行遞迴處理
        else if (expression instanceof AllComparisonExpression) {
            AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
            processSelect(allComparisonExpression.getSubSelect().getSelectBody());
        }
        else if (expression instanceof AnyComparisonExpression) {
            AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
            processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
        }
        // 如果是 cast 語句,比如:cast(xx as xx),則需要對子查詢進行遞迴處理
        else if (expression instanceof CastExpression) {
            CastExpression castExpression = (CastExpression) expression;
            processCondition(castExpression.getLeftExpression(), table);
        }

        // 拼接查詢條件
        Expression appendCondition = handleCondition(expression, table);
        return Objects.isNull(appendCondition) ? expression : appendCondition;
    }

    /**
     * 判斷是否是已經是無法再拆分的基本表示式 <br/>
     * 比如:列名、常量、函式等
     *
     * @param expression 表示式
     * @return 是否是基本表示式
     */
    protected boolean isBasicExpression(@Nullable Expression expression) {
        return expression instanceof Column
            || expression instanceof LongValue
            || expression instanceof StringValue
            || expression instanceof DoubleValue
            || expression instanceof NullValue
            || expression instanceof TimeValue
            || expression instanceof TimestampValue
            || expression instanceof DateValue;
    }

    /**
     * 返回一個查詢條件,該查詢條件將替換{@code table}原有的{@code where}條件
     *
     * @param expression 原有的查詢條件
     * @param table 指定的表
     * @return 查詢條件
     */
    @Nullable
    protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        // 如果是一個標準表名,且改表名在租戶表列表中,則為查詢條件新增租戶條件
        if (!(table instanceof Table)) {
            return null;
        }
        String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
        if (Objects.nonNull(tenantColumn)) {
            return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
        }
        return null;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在別名,則欄位應該變“表別名.欄位名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租戶資訊
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租戶ID
         */
        private final String tenantId;
        /**
         * 要新增租戶條件的表名稱與對應的租戶欄位
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

現在,針對預期的複雜場景,我們再來測試一下:

public static void main(String[] args) {
    Map<String, String> tablesWithTenantColumn = Maps.newHashMap();
    tablesWithTenantColumn.put("t", "tenant_id");
    TenantInfo tenantInfo = new TenantInfo("1", tablesWithTenantColumn);
    TenantSQLInterceptor.setTenantInfo(tenantInfo);

    // 處理包含的複雜子查詢的SQL
    String sql = "select * " +
        "from (select * from t where a = 1) t " +
        "left join (select * from t where b = 2) t2 on t.id = t2.id " +
        "where b in (select * from t where c = 2) and d = 3";
    TenantSQLInterceptor interceptor = new TenantSQLInterceptor();
    String handledSql = interceptor.handle(sql);
    System.out.println(handledSql);
    // 輸出結果:
    // select * 
    // from (select * from t where a = 1 and tenant_id = '1') t 
    // left join (select * from t where b = 2 and tenant_id = '1') t2 on t.id = t2.id 
    // where b in (select * from t where c = 2 and tenant_id = '1') and d = 3
}

完美!

1.4.分離公共程式碼

這個 SQL 攔截器已經可以完美滿足我們的大部分需求了。現在功能已經實現,可以看看程式碼層面有什麼可以最佳化的地方了。
我們再次分析一下上述程式碼,會注意到,上面的解析器其實幹了兩件事情:

  • 解析 SQL,並在遞迴獲取不可再拆分的“根” SQL 後,替換其 where 條件。
  • 將 SQL 的 where 條件替換或追加上租戶條件。

換而言之,第一步的邏輯似乎與“租戶攔截”這個需求無關,它顯然可以抽離為一個獨立的元件以便後續複用。因此,這裡我們將這個新元件根據其功能命名為 AbstractConditionSqlHandler

/**
 * <p>SQL處理器,用於攔截SQL語句並修改其中的查詢條件,
 * 該處理器支援處理巢狀查詢、聯合查詢、關聯查詢等多種查詢方式。
 *
 * @author huangchengxing
 * @see #handle
 * @see #handleCondition
 */
@Setter
@Slf4j
public abstract class AbstractConditionSqlHandler {

    /**
     * 處理SQL語句
     *
     * @param sql SQL語句
     * @return 處理後的SQL語句
     */
    @Nullable
    public String handle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
        } catch (JSQLParserException e) {
            log.error("SQL 解析失敗: {}", sql, e);
            throw new RuntimeException("SQL 解析失敗");
        }
        return statements;
    }

    private Statement doHandle(Statement statement) {
        try {
            if (statement instanceof Select) {
                processSelect(((Select) statement).getSelectBody());
            } else if (statement instanceof Update) {
                processUpdate((Update) statement);
            } else if (statement instanceof Delete) {
                processDelete((Delete) statement);
            } else if (statement instanceof Insert) {
                processInsert((Insert) statement);
            }
        } catch (Exception ex) {
            log.error("SQL 處理失敗: {}", statement, ex);
            throw new RuntimeException("SQL 處理失敗");
        }
        return statement;
    }

    private void processSelect(SelectBody selectBody) {
        // 普通查詢
        if (selectBody instanceof PlainSelect) {
            processSelect((PlainSelect) selectBody);
        }
        // 巢狀查詢,比如 select xx from (select yy from t)
        else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelect(withItem.getSelectBody());
            }
        }
        // 聯合查詢,比如 union
        else if (selectBody instanceof SetOperationList) {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (CollUtil.isNotEmpty(operationList.getSelects())) {
                operationList.getSelects().forEach(this::processSelect);
            }
        }
        // 值查詢,比如 select 1, 2, 3
        else if (selectBody instanceof ValuesStatement) {
            List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
            if (CollUtil.isNotEmpty(expressions)) {
                expressions.forEach(exp -> processCondition(exp, null));
            }
        } else {
            log.error("無法解析的 select 語句:{}({})", selectBody, selectBody.getClass());
            throw new RuntimeException("不支援的查詢語句:" + selectBody.getClass().getName());
        }
    }

    /**
     * 處理插入語句
     *
     * @param insert 插入語句
     */
    protected void processInsert(Insert insert) {
        // do nothing
    }

    /**
     * 處理刪除語句
     *
     * @param delete 刪除語句
     */
    protected void processDelete(Delete delete) {
        Table table = delete.getTable();
        delete.setWhere(processCondition(delete.getWhere(), table));
        // 如果還存在關聯查詢
        List<Join> joins = delete.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 處理更新語句
     *
     * @param update 更新語句
     */
    protected void processUpdate(Update update) {
        Table table = update.getTable();
        update.setWhere(processCondition(update.getWhere(), table));
        // 如果還存在關聯查詢
        List<Join> joins = update.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 處理查詢語句
     *
     * @param plainSelect 查詢語句
     */
    protected void processSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        // 如果是普通的表名
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
        }
        // 如果是子查詢,比如 select * from (select xxx from yyy)
        else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelect(subSelect.getSelectBody());
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
        }
        // 如果是帶有特殊函式的子查詢,比如 lateral (select sum(*) from yyy)
        else if (fromItem instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
        }
        // 未知型別的查詢,直接報錯
        else {
            log.error("無法解析的 from 語句:{}({})", fromItem, fromItem.getClass());
            throw new RuntimeException("不支援的查詢語句:" + fromItem.getClass().getName());
        }

        // 如果還存在關聯查詢
        List<Join> joins = plainSelect.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 處理關聯查詢
     *
     * @param join 關聯查詢
     */
    protected void processJoin(Join join) {
        FromItem joinTable = join.getRightItem();
        if (joinTable instanceof Table) {
            Table table = (Table) joinTable;
            join.setOnExpression(processCondition(join.getOnExpression(), table));
        }
        else if (joinTable instanceof SubSelect) {
            processSelect(((SubSelect) joinTable).getSelectBody());
        }
        else if (joinTable instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
        }
        else {
            log.error("無法解析的 join 語句:{}({})", joinTable, joinTable.getClass());
            throw new RuntimeException("不支援的查詢語句:" + joinTable.getClass().getName());
        }
    }

    /**
     * <p>獲取新增了租戶條件的查詢條件,若條件中存在子查詢,則也會為子查詢新增租戶條件。
     *
     * @param expression 條件表示式
     * @param table 表
     * @return 新增租戶條件後的條件表示式
     */
    @SuppressWarnings({"java:S6541", "java:S3776"})
    protected Expression processCondition(@Nullable Expression expression, FromItem table) {
        // 如果已經不可拆分的表示式,則直接返回
        if (isBasicExpression(expression)) {
            return expression;
        }
        // 如果是子查詢,則需要對子查詢進行遞迴處理
        else if (expression instanceof SubSelect) {
            processSelect(((SubSelect) expression).getSelectBody());
        }
        // 如果是 in 條件,比如:xxx in (select xx from yy……),則需要對子查詢進行遞迴處理
        else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelect(((SubSelect) rightItems).getSelectBody());
            }
        }
        // 如果是 not 或者 != 條件,則需要對裡面的條件進行遞迴處理
        else if (expression instanceof NotExpression) {
            NotExpression notExpression = (NotExpression) expression;
            processCondition(notExpression.getExpression(), table);
        }
        // 如果是 (xxx != xxx),則需要對括號裡面的表示式進行遞迴處理
        else if (expression instanceof Parenthesis) {
            Parenthesis parenthesis = (Parenthesis) expression;
            Expression content = parenthesis.getExpression();
            processCondition(content, table);
        }
        // 如果是二元表示式,比如:xx = xx,xx > xx,則需要對左右兩邊的表示式進行遞迴處理
        else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) expression;
            Expression left = binaryExpression.getLeftExpression();
            processCondition(left, table);
            Expression right = binaryExpression.getRightExpression();
            processCondition(right, table);
        }
        // 如果是函式,比如:if(xx, xx) ,則需要對函式的引數進行遞迴處理
        else if (expression instanceof Function) {
            Function function = (Function) expression;
            ExpressionList parameters = function.getParameters();
            if (parameters != null) {
                parameters.getExpressions().forEach(param -> processCondition(param, table));
            }
        }
        // 如果是 case when 語句,則需要對 when 和 then 兩個條件進行遞迴處理
        else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            processCondition(whenClause.getWhenExpression(), table);
            processCondition(whenClause.getThenExpression(), table);
        }
        // 如果是 case 語句,則需要對 switch、when、then、else 四個條件進行遞迴處理
        else if (expression instanceof CaseExpression) {
            CaseExpression caseExpression = (CaseExpression) expression;
            processCondition(caseExpression.getSwitchExpression(), table);
            List<WhenClause> whenClauses = caseExpression.getWhenClauses();
            if (CollUtil.isNotEmpty(whenClauses)) {
                whenClauses.forEach(whenClause -> {
                    processCondition(whenClause.getWhenExpression(), table);
                    processCondition(whenClause.getThenExpression(), table);
                });
            }
            processCondition(caseExpression.getElseExpression(), table);
        }
        // 如果是 exists 語句,比如:exists (select xx from yy……),則需要對子查詢進行遞迴處理
        else if (expression instanceof ExistsExpression) {
            Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
            if (existsExpression instanceof SubSelect) {
                processSelect(((SubSelect) existsExpression).getSelectBody());
            }
        }
        // 如果是 all 或者 any 語句,比如:xx > all (select xx from yy……),則需要對子查詢進行遞迴處理
        else if (expression instanceof AllComparisonExpression) {
            AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
            processSelect(allComparisonExpression.getSubSelect().getSelectBody());
        }
        else if (expression instanceof AnyComparisonExpression) {
            AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
            processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
        }
        // 如果是 cast 語句,比如:cast(xx as xx),則需要對子查詢進行遞迴處理
        else if (expression instanceof CastExpression) {
            CastExpression castExpression = (CastExpression) expression;
            processCondition(castExpression.getLeftExpression(), table);
        }

        // 拼接查詢條件
        Expression appendCondition = handleCondition(expression, table);
        return Objects.isNull(appendCondition) ? expression : appendCondition;
    }

    /**
     * 返回一個查詢條件,該查詢條件將替換{@code table}原有的{@code where}條件
     *
     * @param expression 原有的查詢條件
     * @param table 指定的表
     * @return 查詢條件
     */
    protected abstract Expression handleCondition(@Nullable Expression expression, FromItem table);

    /**
     * 判斷是否是已經是無法再拆分的基本表示式 <br/>
     * 比如:列名、常量、函式等
     *
     * @param expression 表示式
     * @return 是否是基本表示式
     */
    protected boolean isBasicExpression(@Nullable Expression expression) {
        return expression instanceof Column
            || expression instanceof LongValue
            || expression instanceof StringValue
            || expression instanceof DoubleValue
            || expression instanceof NullValue
            || expression instanceof TimeValue
            || expression instanceof TimestampValue
            || expression instanceof DateValue;
    }
}

接著,對於原本的 SQL 攔截器,我們令其繼承 AbstractConditionSqlHandler,然後更換一個更合適的名字 ContextTenantConditionSqlHandler

/**
 * SQL攔截器,用於為SQL語句新增租戶條件。
 * 每次執行SQL時,將會檢查當前執行緒上下文中是否存在租戶資訊,如果存在,則會為查詢語句新增租戶條件,否則直接略過。
 *
 * @author huangchengxing
 * @see ContextTenantConditionSqlHandlerAdvisor
 */
@Slf4j
public class ContextTenantConditionSqlHandler extends AbstractConditionSqlHandler {

    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 設定租戶資訊
     *
     * @param tenantInfo 租戶資訊
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租戶資訊
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }

    @Override
    public String handle(String sql) {
        // 如果未設定租戶資訊,則直接返回原始SQL
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        if (Objects.isNull(tenantInfo)) {
            return sql;
        }
        log.debug("租戶攔截器攔截原始 SQL: {}", sql);
        String handledSql = super.handle(sql);
        log.info("租戶攔截器攔截後 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    @Override
    @Nullable
    protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        // 如果是一個標準表名,且改表名在租戶表列表中,則為查詢條件新增租戶條件
        if (!(table instanceof Table)) {
            return null;
        }
        String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
        if (Objects.nonNull(tenantColumn)) {
            return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
        }
        return null;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在別名,則欄位應該變“表別名.欄位名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租戶資訊
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租戶ID
         */
        private final String tenantId;
        /**
         * 要新增租戶條件的表名稱與對應的租戶欄位
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

1.5.與 JPA 結合使用

JPA 的預設實現 Hibernate 提供了 StatementInspector 介面,我們實現一個自定義的實現類,然後讓基礎上文實現好的租戶解析器即可 ContextTenantConditionSqlHandler

/**
 * SQL攔截器,用於為SQL語句新增租戶條件。
 * 每次執行SQL時,將會檢查當前執行緒中是否存在租戶資訊,如果存在,則會為查詢語句新增租戶條件,否則直接略過。
 *
 * @author huangchengxing
 */
@Slf4j
@RequiredArgsConstructor
public class HibernateTenantStatementInspector
    extends ContextTenantConditionSqlHandler implements StatementInspector {

    @Override
    public String inspect(String sql) {
        return handle(sql);
    }
}

同理,我們也可以結合 Mybatis 或其他的框架實現類似的效果。

2.租戶攔截器

顯然,我們不可能無條件的攔截所有的查詢,有些查詢本身不需要進行攔截,而有些查詢當訪問者為管理員時也不需要攔截……總而言之,對應租戶攔截,我們需要採用白名單而不是黑名單的方式,因此最好的實現方法就是搞一個切面,然後只對帶有特定註解的方法的呼叫進行攔截。

2.1.註解類

我們定義一個 @TenantOperation 註解,該註解可以被用於方法或者類上,當用於類上的時候等於類中所有的方法都應用攔截:

/**
 * 表明方法是一個租戶操作方法,需要在相關的SQL中加入租戶過濾條件
 *
 * @author huangchengxing
 * @see ContextTenantConditionSqlHandlerAdvisor
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface TenantOperation {

    /**
     * 表配置
     *
     * @return 表配置
     */
    Tables[] value() default {};

    /**
     * 是否對當前方法與後續呼叫鏈不進行租戶攔截
     *
     * @return boolean
     * @see Ignore
     */
    boolean ignore() default false;

    /**
     * 對當前方法與後續呼叫鏈不進行租戶攔截
     */
    @TenantOperation(ignore = true) // 基於 Springt 合成註解機制的擴充套件註解
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE})
    @interface Ignore {}

    /**
     * 表配置
     */
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @interface Tables {

        /**
         * 租戶欄位名,不指定時預設遵循配置檔案中的欄位名
         *
         * @return String
         */
        String column() default "";

        /**
         * 需要新增過濾條件的表名,不指定時預設遵循配置檔案中的表名
         *
         * @return String
         */
        String[] tables() default {};
    }
}

此外,為了便於使用,註解還支援直接指定要攔截的表和欄位,以便覆蓋預設配置檔案中的配置。

2.2.方法攔截器

為了便於後續擴充套件,這裡筆者沒有基於 Aspect 註解,而是基於 Spring 的方法攔截器,自定義了切點來實現這個效果:

/**
 * 方法攔截器,用於攔截帶有{@link TenantOperation}註解的方法,為涉及的查詢語句新增租戶過濾條件
 *
 * @author huangchengxing
 * @see ContextTenantConditionSqlHandler
 */
@Slf4j
public class ContextTenantConditionSqlHandlerAdvisor implements PointcutAdvisor, MethodInterceptor {

    private static final String INTERCEPT_REQUEST_ENTRY = "tenant";
    private static final TenantOpsInfo NULL = new TenantOpsInfo(null);
    private final Map<Method, TenantOpsInfo> tenantInfoCaches = new ConcurrentReferenceHashMap<>();
    private final TenantOpsInfo opsByDefault;

    public ContextTenantConditionSqlHandlerAdvisor(Map<String, String> tableWithColumns) {
        this.opsByDefault = new TenantOpsInfo(tableWithColumns);
    }

    @Override
    public Object invoke(MethodInvocation methodInvocation) throws Throwable {
        // 從上下文獲取租戶ID
        String tenantId = Optional.ofNullable(RequestUserContext.getUser())
            .map(RequestUserContext.User::getUserId)
            .orElse(null);
        // 若沒有上下文資訊,則直接放行
        if (Objects.isNull(tenantId)) {
            return methodInvocation.proceed();
        }

        // 解析配置資訊
        TenantOpsInfo info = resolveMethod(methodInvocation.getMethod());
        if (info == NULL) {
            return methodInvocation.proceed();
        }

        // 設定租戶資訊
        try {
            ContextTenantConditionSqlHandler.setTenantInfo(info.getTenantInfo(tenantId));
            return methodInvocation.proceed();
        } finally {
            ContextTenantConditionSqlHandler.clearTenantInfo();
        }
    }

    private TenantOpsInfo resolveMethod(Method method) {
        return tenantInfoCaches.computeIfAbsent(method, m -> {
            // 從方法上或類上獲取註解
            TenantOperation annotation = Optional.ofNullable(AnnotatedElementUtils.findMergedAnnotation(method, TenantOperation.class))
                .orElse(AnnotatedElementUtils.findMergedAnnotation(method.getDeclaringClass(), TenantOperation.class));
            if (Objects.isNull(annotation)) {
                return NULL;
            }
            // 若註解未指定column和tables,則使用預設值
            TenantOperation.Tables[] tables = annotation.value();
            if (ArrayUtil.isEmpty(tables)) {
                return opsByDefault;
            }
            // 若指定了column和tables,則使用指定值
            Map<String, String> tableWithColumns = new HashMap<>(tables.length);
            for (TenantOperation.Tables table : tables) {
                String column = table.column();
                for (String tableName : table.tables()) {
                    tableWithColumns.put(tableName, column);
                }
            }
            return new TenantOpsInfo(tableWithColumns);
        });
    }

    @RequiredArgsConstructor
    private static class TenantOpsInfo {
        private final Map<String, String> tablesWithTenantColumn;
        public ContextTenantConditionSqlHandler.TenantInfo getTenantInfo(String tenantId) {
            return new ContextTenantConditionSqlHandler.TenantInfo(tenantId, tablesWithTenantColumn);
        }
    }

    @Override
    public @NonNull Pointcut getPointcut() {
        return TenantQueryPointcut.INSTANCE;
    }

    @Override
    public @NonNull Advice getAdvice() {
        return this;
    }

    @Override
    public boolean isPerInstance() {
        return false;
    }

    // 自定義切點,攔截帶有 @TenantOperation 註解的方法,或宣告類上帶有 @TenantOperation 註解的全部方法
    private static class TenantQueryPointcut extends StaticMethodMatcher implements Pointcut {
        public static final TenantQueryPointcut INSTANCE = new TenantQueryPointcut();
        @Override
        public @NonNull ClassFilter getClassFilter() {
            return ClassFilter.TRUE;
        }
        @Override
        public @NonNull MethodMatcher getMethodMatcher() {
            return this;
        }
        @Override
        public boolean matches(@NonNull Method method, @NonNull Class<?> type) {
            return AnnotatedElementUtils.isAnnotated(method, TenantOperation.class)
                || AnnotatedElementUtils.isAnnotated(type, TenantOperation.class);
        }
    }
}

3.使用

3.1.配置類

首先,我們先定義一個配置類以在專案中啟用上述元件:

/**
 * <p>租戶攔截器配置,啟用後可以為指定的查詢方法新增租戶過濾條件。 <br/>
 * 可透過配置檔案進行配置:<br/>
 * <pre>
 * # JPA 啟用租戶 SQL 攔截器
 * spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
 * # 啟用租戶攔截器
 * tenant.interceptor.enabled=true
 * # 需要攔截的表
 * tenant.interceptor.tables[0].column = tenant_id
 * tenant.interceptor.tables[0].tableNames = table1, table2
 * </pre>
 *
 * @author huangchengxing
 */
@Slf4j
@ConditionalOnProperty(prefix = TenantInterceptorConfig.Properties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@EnableConfigurationProperties(TenantInterceptorConfig.Properties.class)
@Configuration
public class TenantInterceptorConfig {

    @Bean
    public ContextTenantConditionSqlHandlerAdvisor tenantQueryAdvisor(Properties properties) {
        log.info("啟用租戶攔截器,需要攔截的表:{}", properties.getTables());
        Map<String, String> tableWithColumns = new HashMap<>(16);
        properties.getTables().forEach(ts -> ts.getTableNames().forEach(t -> {
            Assert.isFalse(tableWithColumns.containsKey(t), "同一張表具備只允許具備一個租戶欄位:{}", t);
            tableWithColumns.put(t, ts.getColumn());
        }));
        return new ContextTenantConditionSqlHandlerAdvisor(tableWithColumns);
    }

    /**
     * @author huangchengxing
     */
    @ConfigurationProperties(prefix = Properties.CONFIG_PREFIX)
    @Data
    public static class Properties {

        public static final String CONFIG_PREFIX = "tenant.interceptor";

        /**
         * 表配置
         */
        private List<Tables> tables = new ArrayList<>();

        @Data
        public static class Tables {

            /**
             * 租戶欄位名
             */
            private String column;

            /**
             * 需要攔截的表名
             */
            private Set<String> tableNames;
        }
    }
}

3.2.配置檔案

隨後在配置檔案中啟用攔截器,並配置好要攔截的表:

# 啟用租戶 SQL 攔截器
spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
# 啟用租戶攔截器
tenant.interceptor.enabled=true
# 攔截 t_user, t_resource, t_assest 表中的 tenant_id 欄位
tenant.interceptor.tables[0].column=tenant_id
tenant.interceptor.tables[0].table-names=t_user, t_resource, t_assest

3.3.新增註解

最後,我們只要在對應的類或者方法上新增 @TenantOperation 即可:

@TenantOperation // 預設所有方法都要應用攔截
@RestController
public class ResourceController {

    // @TenantOperation 因為類上已經加了,所以方法上可以不用加
    @GetMapping
    public List<Resource> listResource1(List<Integer> ids) {
        // do something
    }

    @TenantOperation.Ingore // 該方法不進行攔截
    @GetMapping
    public List<Resource> listResource2(List<Integer> ids) {
        // do something
    }
}

相關文章