Search in sources:

Example 16 with SQLUpdateStatement

use of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in project druid by alibaba.

In class SQLUpdateBuilderImpl, method set.

/**
 * Adds one SET item to the underlying UPDATE statement for every given expression.
 *
 * Each string is converted with {@code SQLUtils.toUpdateSetItem} using this
 * builder's {@code dbType} and appended in the order supplied.
 *
 * @param items the textual SET expressions to append
 * @return this builder, to allow call chaining
 */
public SQLUpdateBuilderImpl set(String... items) {
    final SQLUpdateStatement statement = getSQLUpdateStatement();
    for (int index = 0; index < items.length; index++) {
        statement.addItem(SQLUtils.toUpdateSetItem(items[index], dbType));
    }
    return this;
}
Also used : SQLUpdateSetItem(com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem) SQLUpdateStatement(com.alibaba.druid.sql.ast.statement.SQLUpdateStatement)

Example 17 with SQLUpdateStatement

use of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in project druid by alibaba.

In class WallProvider, method checkInternal.

/**
 * Runs the complete wall check for one SQL text and returns the outcome.
 *
 * The steps are, in order:
 * <ol>
 *   <li>count the call and short-circuit for privileged callers;</li>
 *   <li>unless a tenant table pattern is configured, return the result of
 *       {@code checkWhiteAndBlackList} when it yields one;</li>
 *   <li>parse the SQL, turning comment/syntax problems into violations;</li>
 *   <li>run the wall visitor over every statement except the leading run of
 *       {@code MySqlHintStatement}s;</li>
 *   <li>register the SQL through {@code addBlackSql} / {@code addWhiteSql}
 *       (skipped when an update-check handler applies or the SQL is too long);</li>
 *   <li>build the {@code WallCheckResult}, using the re-rendered SQL when the
 *       visitor reports that it modified the statements.</li>
 * </ol>
 *
 * @param sql the SQL text to check
 * @return the check result carrying the (possibly rewritten) SQL, the violations
 *         found and the update-check items collected by the visitor
 */
private WallCheckResult checkInternal(String sql) {
    checkCount.incrementAndGet();
    WallContext context = WallContext.current();
    if (config.isDoPrivilegedAllow() && ispPrivileged()) {
        // privileged callers bypass every check; only the SQL text is recorded
        WallCheckResult checkResult = new WallCheckResult();
        checkResult.setSql(sql);
        return checkResult;
    }
    // first step, check whiteList (skipped when a tenant table pattern is configured)
    boolean multiTenant = config.getTenantTablePattern() != null && config.getTenantTablePattern().length() > 0;
    if (!multiTenant) {
        WallCheckResult checkResult = checkWhiteAndBlackList(sql);
        if (checkResult != null) {
            checkResult.setSql(sql);
            return checkResult;
        }
    }
    // from here on the SQL is actually parsed and visited
    hardCheckCount.incrementAndGet();
    final List<Violation> violations = new ArrayList<Violation>();
    List<SQLStatement> statementList = new ArrayList<SQLStatement>();
    boolean syntaxError = false;
    boolean endOfComment = false;
    try {
        SQLStatementParser parser = createParser(sql);
        parser.getLexer().setCommentHandler(WallCommentHandler.instance);
        if (!config.isCommentAllow()) {
            // deny comment
            parser.getLexer().setAllowComment(false);
        }
        if (!config.isCompleteInsertValuesCheck()) {
            // only parse a bounded number of INSERT value rows
            parser.setParseCompleteValues(false);
            parser.setParseValuesSize(config.getInsertValuesCheckSize());
        }
        parser.parseStatementList(statementList);
        // anything left after the last statement means the SQL was not fully consumed
        final Token lastToken = parser.getLexer().token();
        if (lastToken != Token.EOF && config.isStrictSyntaxCheck()) {
            violations.add(new IllegalSQLObjectViolation(ErrorCode.SYNTAX_ERROR, "not terminal sql, token " + lastToken, sql));
        }
        endOfComment = parser.getLexer().isEndOfComment();
    } catch (NotAllowCommentException e) {
        violations.add(new IllegalSQLObjectViolation(ErrorCode.COMMENT_STATEMENT_NOT_ALLOW, "comment not allow", sql));
        incrementCommentDeniedCount();
    } catch (ParserException e) {
        syntaxErrorCount.incrementAndGet();
        syntaxError = true;
        if (config.isStrictSyntaxCheck()) {
            violations.add(new SyntaxErrorViolation(e, sql));
        }
    } catch (Exception e) {
        // non-parser failures are only reported under strict syntax checking and
        // neither bump syntaxErrorCount nor set syntaxError
        // NOTE(review): presumably a deliberate best-effort - confirm
        if (config.isStrictSyntaxCheck()) {
            violations.add(new SyntaxErrorViolation(e, sql));
        }
    }
    if (statementList.size() > 1 && !config.isMultiStatementAllow()) {
        violations.add(new IllegalSQLObjectViolation(ErrorCode.MULTI_STATEMENT, "multi-statement not allow", sql));
    }
    WallVisitor visitor = createWallVisitor();
    visitor.setSqlEndOfComment(endOfComment);
    boolean lastIsHint = false;
    for (int i = 0; i < statementList.size(); i++) {
        SQLStatement stmt = statementList.get(i);
        if ((i == 0 || lastIsHint) && stmt instanceof MySqlHintStatement) {
            lastIsHint = true;
            continue;
        }
        // Only the leading run of hint statements is exempt. Without this reset,
        // a hint appearing after an ordinary statement would also skip the visitor
        // whenever the SQL happened to start with a hint.
        lastIsHint = false;
        try {
            stmt.accept(visitor);
        } catch (ParserException e) {
            violations.add(new SyntaxErrorViolation(e, sql));
        }
    }
    if (visitor.getViolations().size() > 0) {
        violations.addAll(visitor.getViolations());
    }
    // NOTE(review): context is dereferenced here without a null check although it is
    // null-checked further down - confirm whether WallContext.current() can be null here
    Map<String, WallSqlTableStat> tableStat = context.getTableStats();
    // the update check applies when a handler is configured and any UPDATE statement
    // targets a table that has update-check columns configured
    boolean updateCheckHandlerEnable = false;
    {
        WallUpdateCheckHandler updateCheckHandler = config.getUpdateCheckHandler();
        if (updateCheckHandler != null) {
            for (SQLStatement stmt : statementList) {
                if (stmt instanceof SQLUpdateStatement) {
                    SQLUpdateStatement updateStmt = (SQLUpdateStatement) stmt;
                    SQLName table = updateStmt.getTableName();
                    if (table != null) {
                        String tableName = table.getSimpleName();
                        Set<String> updateCheckColumns = config.getUpdateCheckTable(tableName);
                        if (updateCheckColumns != null && updateCheckColumns.size() > 0) {
                            updateCheckHandlerEnable = true;
                            break;
                        }
                    }
                }
            }
        }
    }
    WallSqlStat sqlStat = null;
    if (violations.size() > 0) {
        violationCount.incrementAndGet();
        if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
            sqlStat = addBlackSql(sql, tableStat, context.getFunctionStats(), violations, syntaxError);
        }
    } else {
        if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
            // with a select limit configured, SQL containing a SELECT is not white-listed
            boolean selectLimit = false;
            if (config.getSelectLimit() > 0) {
                for (SQLStatement stmt : statementList) {
                    if (stmt instanceof SQLSelectStatement) {
                        selectLimit = true;
                        break;
                    }
                }
            }
            if (!selectLimit) {
                sqlStat = addWhiteSql(sql, tableStat, context.getFunctionStats(), syntaxError);
            }
        }
    }
    if (sqlStat == null && updateCheckHandlerEnable) {
        // update-checked SQL gets a fresh stat object instead of a black/white list entry
        sqlStat = new WallSqlStat(tableStat, context.getFunctionStats(), violations, syntaxError);
    }
    Map<String, WallSqlTableStat> tableStats = null;
    Map<String, WallSqlFunctionStat> functionStats = null;
    if (context != null) {
        tableStats = context.getTableStats();
        functionStats = context.getFunctionStats();
        recordStats(tableStats, functionStats);
    }
    WallCheckResult result;
    if (sqlStat != null) {
        context.setSqlStat(sqlStat);
        result = new WallCheckResult(sqlStat, statementList);
    } else {
        result = new WallCheckResult(null, violations, tableStats, functionStats, statementList, syntaxError);
    }
    // hand back the re-rendered SQL only when the visitor actually changed the statements
    String resultSql;
    if (visitor.isSqlModified()) {
        resultSql = SQLUtils.toSQLString(statementList, dbType);
    } else {
        resultSql = sql;
    }
    result.setSql(resultSql);
    result.setUpdateCheckItems(visitor.getUpdateCheckItems());
    return result;
}
Also used : SyntaxErrorViolation(com.alibaba.druid.wall.violation.SyntaxErrorViolation) IllegalSQLObjectViolation(com.alibaba.druid.wall.violation.IllegalSQLObjectViolation) HashSet(java.util.HashSet) Set(java.util.Set) SyntaxErrorViolation(com.alibaba.druid.wall.violation.SyntaxErrorViolation) ArrayList(java.util.ArrayList) IllegalSQLObjectViolation(com.alibaba.druid.wall.violation.IllegalSQLObjectViolation) Token(com.alibaba.druid.sql.parser.Token) SQLStatement(com.alibaba.druid.sql.ast.SQLStatement) NotAllowCommentException(com.alibaba.druid.sql.parser.NotAllowCommentException) ParserException(com.alibaba.druid.sql.parser.ParserException) SQLStatementParser(com.alibaba.druid.sql.parser.SQLStatementParser) SQLUpdateStatement(com.alibaba.druid.sql.ast.statement.SQLUpdateStatement) SQLName(com.alibaba.druid.sql.ast.SQLName) MySqlHintStatement(com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlHintStatement) ParserException(com.alibaba.druid.sql.parser.ParserException) NotAllowCommentException(com.alibaba.druid.sql.parser.NotAllowCommentException) SQLSelectStatement(com.alibaba.druid.sql.ast.statement.SQLSelectStatement)

Example 18 with SQLUpdateStatement

use of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in project druid by alibaba.

In class SQLASTVisitorAdapterTest, method test_adapter.

/**
 * Smoke test: a freshly constructed instance of each listed AST node type is
 * passed to a plain {@code SQLASTVisitorAdapter}. The test contains no
 * assertions; it passes as long as no {@code accept} call throws.
 */
public void test_adapter() throws Exception {
    final SQLASTVisitorAdapter visitor = new SQLASTVisitorAdapter();

    // query building blocks and basic DML/DDL
    new SQLBinaryOpExpr().accept(visitor);
    new SQLInListExpr().accept(visitor);
    new SQLSelectQueryBlock().accept(visitor);
    new SQLDropTableStatement().accept(visitor);
    new SQLCreateTableStatement().accept(visitor);
    new SQLDeleteStatement().accept(visitor);
    new SQLCurrentOfCursorExpr().accept(visitor);
    new SQLInsertStatement().accept(visitor);
    new SQLUpdateStatement().accept(visitor);
    new SQLNotNullConstraint().accept(visitor);
    new SQLMethodInvokeExpr().accept(visitor);
    new SQLCallStatement().accept(visitor);

    // quantified and default expressions
    new SQLSomeExpr().accept(visitor);
    new SQLAnyExpr().accept(visitor);
    new SQLAllExpr().accept(visitor);
    new SQLDefaultExpr().accept(visitor);

    // miscellaneous statements
    new SQLCommentStatement().accept(visitor);
    new SQLDropViewStatement().accept(visitor);
    new SQLSavePointStatement().accept(visitor);
    new SQLReleaseSavePointStatement().accept(visitor);
    new SQLCreateDatabaseStatement().accept(visitor);

    // ALTER TABLE family and remaining nodes
    new SQLAlterTableDropIndex().accept(visitor);
    new SQLOver().accept(visitor);
    new SQLWithSubqueryClause().accept(visitor);
    new SQLAlterTableAlterColumn().accept(visitor);
    new SQLAlterTableStatement().accept(visitor);
    new SQLAlterTableDisableConstraint().accept(visitor);
    new SQLAlterTableEnableConstraint().accept(visitor);
    new SQLColumnCheck().accept(visitor);
    new SQLExprHint().accept(visitor);
    new SQLAlterTableDropConstraint().accept(visitor);
}
Also used : SQLASTVisitorAdapter(com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter) SQLCreateTableStatement(com.alibaba.druid.sql.ast.statement.SQLCreateTableStatement) SQLMethodInvokeExpr(com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr) SQLAlterTableDropIndex(com.alibaba.druid.sql.ast.statement.SQLAlterTableDropIndex) SQLAlterTableStatement(com.alibaba.druid.sql.ast.statement.SQLAlterTableStatement) SQLAllExpr(com.alibaba.druid.sql.ast.expr.SQLAllExpr) SQLReleaseSavePointStatement(com.alibaba.druid.sql.ast.statement.SQLReleaseSavePointStatement) SQLNotNullConstraint(com.alibaba.druid.sql.ast.statement.SQLNotNullConstraint) SQLOver(com.alibaba.druid.sql.ast.SQLOver) SQLDropTableStatement(com.alibaba.druid.sql.ast.statement.SQLDropTableStatement) SQLCreateDatabaseStatement(com.alibaba.druid.sql.ast.statement.SQLCreateDatabaseStatement) SQLDropViewStatement(com.alibaba.druid.sql.ast.statement.SQLDropViewStatement) SQLInsertStatement(com.alibaba.druid.sql.ast.statement.SQLInsertStatement) SQLBinaryOpExpr(com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr) SQLAlterTableDisableConstraint(com.alibaba.druid.sql.ast.statement.SQLAlterTableDisableConstraint) SQLSavePointStatement(com.alibaba.druid.sql.ast.statement.SQLSavePointStatement) SQLAlterTableDropConstraint(com.alibaba.druid.sql.ast.statement.SQLAlterTableDropConstraint) SQLUpdateStatement(com.alibaba.druid.sql.ast.statement.SQLUpdateStatement) SQLCallStatement(com.alibaba.druid.sql.ast.statement.SQLCallStatement) SQLSomeExpr(com.alibaba.druid.sql.ast.expr.SQLSomeExpr) SQLInListExpr(com.alibaba.druid.sql.ast.expr.SQLInListExpr) SQLExprHint(com.alibaba.druid.sql.ast.statement.SQLExprHint) SQLWithSubqueryClause(com.alibaba.druid.sql.ast.statement.SQLWithSubqueryClause) SQLDeleteStatement(com.alibaba.druid.sql.ast.statement.SQLDeleteStatement) SQLCurrentOfCursorExpr(com.alibaba.druid.sql.ast.expr.SQLCurrentOfCursorExpr) SQLCommentStatement(com.alibaba.druid.sql.ast.statement.SQLCommentStatement) 
SQLColumnCheck(com.alibaba.druid.sql.ast.statement.SQLColumnCheck) SQLAnyExpr(com.alibaba.druid.sql.ast.expr.SQLAnyExpr) SQLAlterTableAlterColumn(com.alibaba.druid.sql.ast.statement.SQLAlterTableAlterColumn) SQLSelectQueryBlock(com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock) SQLDefaultExpr(com.alibaba.druid.sql.ast.expr.SQLDefaultExpr) SQLAlterTableEnableConstraint(com.alibaba.druid.sql.ast.statement.SQLAlterTableEnableConstraint)

Example 19 with SQLUpdateStatement

use of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in project druid by alibaba.

In class MySqlUpdateTest_15, method test_0.

/**
 * Parses an UPDATE whose WHERE clause contains an IN sub-select and verifies
 * the schema statistics, the upper- and lower-case MySQL renderings, wall
 * validation, and the string form of the WHERE expression.
 */
public void test_0() throws Exception {
    final String sql = "update students set name='test' where id in (select stu_id from score where s <100)";
    MySqlStatementParser mysqlParser = new MySqlStatementParser(sql);
    List<SQLStatement> parsed = mysqlParser.parseStatementList();
    SQLStatement first = parsed.get(0);
    print(parsed);
    assertEquals(1, parsed.size());

    // collect table/column statistics from the statement
    MySqlSchemaStatVisitor statVisitor = new MySqlSchemaStatVisitor();
    first.accept(statVisitor);
    assertEquals(2, statVisitor.getTables().size());
    assertEquals(4, statVisitor.getColumns().size());
    assertTrue(statVisitor.containsTable("students"));
    assertTrue(statVisitor.containsTable("score"));
    assertTrue(statVisitor.getColumns().contains(new Column("students", "name")));
    assertTrue(statVisitor.getColumns().contains(new Column("score", "stu_id")));

    // default (upper-case keyword) rendering
    String upperCaseSql = SQLUtils.toMySqlString(first);
    assertEquals("UPDATE students\n"
            + "SET name = 'test'\n"
            + "WHERE id IN (\n"
            + "\t\tSELECT stu_id\n"
            + "\t\tFROM score\n"
            + "\t\tWHERE s < 100\n"
            + "\t)", upperCaseSql);

    // lower-case keyword rendering
    String lowerCaseSql = SQLUtils.toMySqlString(first, SQLUtils.DEFAULT_LCASE_FORMAT_OPTION);
    assertEquals("update students\n"
            + "set name = 'test'\n"
            + "where id in (\n"
            + "\t\tselect stu_id\n"
            + "\t\tfrom score\n"
            + "\t\twhere s < 100\n"
            + "\t)", lowerCaseSql);

    assertTrue(WallUtils.isValidateMySql(sql));

    // the WHERE expression rendered on its own
    SQLUpdateStatement updateStatement = (SQLUpdateStatement) first;
    SQLExpr whereClause = updateStatement.getWhere();
    assertEquals("id IN (\n"
            + "\tSELECT stu_id\n"
            + "\tFROM score\n"
            + "\tWHERE s < 100\n"
            + ")", whereClause.toString());
}
Also used : MySqlSchemaStatVisitor(com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor) Column(com.alibaba.druid.stat.TableStat.Column) SQLUpdateStatement(com.alibaba.druid.sql.ast.statement.SQLUpdateStatement) MySqlStatementParser(com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser) SQLStatement(com.alibaba.druid.sql.ast.SQLStatement) SQLExpr(com.alibaba.druid.sql.ast.SQLExpr)

Example 20 with SQLUpdateStatement

use of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in project druid by alibaba.

In class MySqlUpdateTest_17, method test_0.

/**
 * Parses a parameterised UPDATE that calls GREATEST() and now(), then verifies
 * the schema statistics, the upper- and lower-case string forms, wall
 * validation, and the string form of the WHERE expression.
 */
public void test_0() throws Exception {
    final String sql = "update security_group_ip_count set ip_count=GREATEST(ip_count-?, 0), gmt_modified=now() where group_id=? ";
    MySqlStatementParser mysqlParser = new MySqlStatementParser(sql);
    List<SQLStatement> parsed = mysqlParser.parseStatementList();
    SQLStatement first = parsed.get(0);
    print(parsed);
    assertEquals(1, parsed.size());

    // collect table/column statistics from the statement
    MySqlSchemaStatVisitor statVisitor = new MySqlSchemaStatVisitor();
    first.accept(statVisitor);
    assertEquals(1, statVisitor.getTables().size());
    assertEquals(3, statVisitor.getColumns().size());
    assertTrue(statVisitor.containsTable("security_group_ip_count"));
    assertTrue(statVisitor.getColumns().contains(new Column("security_group_ip_count", "ip_count")));

    // upper-case and lower-case keyword string forms
    assertEquals("UPDATE security_group_ip_count\n"
            + "SET ip_count = GREATEST(ip_count - ?, 0), gmt_modified = now()\n"
            + "WHERE group_id = ?", first.toString());
    assertEquals("update security_group_ip_count\n"
            + "set ip_count = GREATEST(ip_count - ?, 0), gmt_modified = now()\n"
            + "where group_id = ?", first.toLowerCaseString());

    assertTrue(WallUtils.isValidateMySql(sql));

    // the WHERE expression rendered on its own
    SQLUpdateStatement updateStatement = (SQLUpdateStatement) first;
    SQLExpr whereClause = updateStatement.getWhere();
    assertEquals("group_id = ?", whereClause.toString());
}
Also used : MySqlSchemaStatVisitor(com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor) Column(com.alibaba.druid.stat.TableStat.Column) SQLUpdateStatement(com.alibaba.druid.sql.ast.statement.SQLUpdateStatement) MySqlStatementParser(com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser) SQLStatement(com.alibaba.druid.sql.ast.SQLStatement) SQLExpr(com.alibaba.druid.sql.ast.SQLExpr)

Aggregations

SQLUpdateStatement (com.alibaba.druid.sql.ast.statement.SQLUpdateStatement)20 SQLExpr (com.alibaba.druid.sql.ast.SQLExpr)8 SQLStatement (com.alibaba.druid.sql.ast.SQLStatement)7 SQLDeleteStatement (com.alibaba.druid.sql.ast.statement.SQLDeleteStatement)7 SQLInsertStatement (com.alibaba.druid.sql.ast.statement.SQLInsertStatement)6 SQLExprTableSource (com.alibaba.druid.sql.ast.statement.SQLExprTableSource)4 SQLSelectStatement (com.alibaba.druid.sql.ast.statement.SQLSelectStatement)4 SQLStatementParser (com.alibaba.druid.sql.parser.SQLStatementParser)4 SQLTableSource (com.alibaba.druid.sql.ast.statement.SQLTableSource)3 MySqlStatementParser (com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser)3 MySqlSchemaStatVisitor (com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor)3 Column (com.alibaba.druid.stat.TableStat.Column)3 DbType (com.alibaba.druid.DbType)2 SQLAlterTableDropConstraint (com.alibaba.druid.sql.ast.statement.SQLAlterTableDropConstraint)2 SQLAlterTableDropIndex (com.alibaba.druid.sql.ast.statement.SQLAlterTableDropIndex)2 SQLAlterTableStatement (com.alibaba.druid.sql.ast.statement.SQLAlterTableStatement)2 SQLCreateDatabaseStatement (com.alibaba.druid.sql.ast.statement.SQLCreateDatabaseStatement)2 SQLCreateTableStatement (com.alibaba.druid.sql.ast.statement.SQLCreateTableStatement)2 SQLDropTableStatement (com.alibaba.druid.sql.ast.statement.SQLDropTableStatement)2 SQLSelectQueryBlock (com.alibaba.druid.sql.ast.statement.SQLSelectQueryBlock)2