Usage of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in the Alibaba druid project.
Class SQLUpdateBuilderImpl, method set:
/**
 * Appends one SET-clause item per argument to the UPDATE statement this builder wraps.
 *
 * <p>Each string is converted with {@code SQLUtils.toUpdateSetItem} using this builder's
 * {@code dbType}; presumably each one is a single assignment such as {@code "name = 'x'"}.
 *
 * @param items textual SET items, appended in the order given
 * @return this builder, so calls can be chained
 */
public SQLUpdateBuilderImpl set(String... items) {
    final SQLUpdateStatement statement = getSQLUpdateStatement();
    for (int i = 0; i < items.length; i++) {
        statement.addItem(SQLUtils.toUpdateSetItem(items[i], dbType));
    }
    return this;
}
Usage of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in the Alibaba druid project.
Class WallProvider, method checkInternal:
/**
 * Runs the full wall (SQL firewall) check for one SQL string and builds the result.
 *
 * <p>Steps, in the order implemented below:
 * <ol>
 *   <li>privileged short-circuit;</li>
 *   <li>white/black-list lookup, skipped when a tenant table pattern is configured;</li>
 *   <li>parsing, with comment, VALUES and trailing-token handling driven by {@code config};</li>
 *   <li>multi-statement check and AST visit by a {@code WallVisitor};</li>
 *   <li>detection of UPDATEs on tables with configured update-check columns;</li>
 *   <li>black/white-list bookkeeping and statistics recording;</li>
 *   <li>assembly of the {@link WallCheckResult}.</li>
 * </ol>
 *
 * @param sql the SQL text to check
 * @return the check result. Its SQL is regenerated from the statement list when the
 *         visitor reports the AST as modified, otherwise it is {@code sql} unchanged.
 */
private WallCheckResult checkInternal(String sql) {
checkCount.incrementAndGet();
// NOTE(review): context is dereferenced unconditionally below (getTableStats,
// getFunctionStats, setSqlStat) but null-checked before recordStats. Presumably the
// caller always installs a WallContext; confirm and make the handling consistent.
WallContext context = WallContext.current();
// Privileged callers get a bare result carrying the SQL, with no further checking.
if (config.isDoPrivilegedAllow() && ispPrivileged()) {
WallCheckResult checkResult = new WallCheckResult();
checkResult.setSql(sql);
return checkResult;
}
// first step, check whiteList
// The white/black-list lookup is skipped when a non-empty tenant table pattern is
// configured, presumably because the verdict can differ per tenant.
boolean mulltiTenant = config.getTenantTablePattern() != null && config.getTenantTablePattern().length() > 0;
if (!mulltiTenant) {
WallCheckResult checkResult = checkWhiteAndBlackList(sql);
// Non-null means a list lookup already produced a verdict for this SQL; reuse it.
if (checkResult != null) {
checkResult.setSql(sql);
return checkResult;
}
}
// No cached verdict: do the full parse-and-visit ("hard") check.
hardCheckCount.incrementAndGet();
final List<Violation> violations = new ArrayList<Violation>();
List<SQLStatement> statementList = new ArrayList<SQLStatement>();
boolean syntaxError = false;
boolean endOfComment = false;
try {
SQLStatementParser parser = createParser(sql);
parser.getLexer().setCommentHandler(WallCommentHandler.instance);
if (!config.isCommentAllow()) {
// deny comment
parser.getLexer().setAllowComment(false);
}
// With complete VALUES checking disabled, the parser is told not to parse all
// VALUES, and is given getInsertValuesCheckSize() as the size limit.
if (!config.isCompleteInsertValuesCheck()) {
parser.setParseCompleteValues(false);
parser.setParseValuesSize(config.getInsertValuesCheckSize());
}
// Statements are appended into statementList. On an exception it presumably
// keeps whatever was parsed before the failure.
parser.parseStatementList(statementList);
// In strict mode, tokens left over after the statement list are a syntax violation.
final Token lastToken = parser.getLexer().token();
if (lastToken != Token.EOF && config.isStrictSyntaxCheck()) {
violations.add(new IllegalSQLObjectViolation(ErrorCode.SYNTAX_ERROR, "not terminal sql, token " + lastToken, sql));
}
endOfComment = parser.getLexer().isEndOfComment();
} catch (NotAllowCommentException e) {
violations.add(new IllegalSQLObjectViolation(ErrorCode.COMMENT_STATEMENT_NOT_ALLOW, "comment not allow", sql));
incrementCommentDeniedCount();
} catch (ParserException e) {
// Syntax errors are always counted and flagged, but are violations only in strict mode.
syntaxErrorCount.incrementAndGet();
syntaxError = true;
if (config.isStrictSyntaxCheck()) {
violations.add(new SyntaxErrorViolation(e, sql));
}
} catch (Exception e) {
// NOTE(review): unlike ParserException, this path neither sets syntaxError nor bumps
// syntaxErrorCount, and the failure is silently ignored outside strict mode.
if (config.isStrictSyntaxCheck()) {
violations.add(new SyntaxErrorViolation(e, sql));
}
}
if (statementList.size() > 1 && !config.isMultiStatementAllow()) {
violations.add(new IllegalSQLObjectViolation(ErrorCode.MULTI_STATEMENT, "multi-statement not allow", sql));
}
WallVisitor visitor = createWallVisitor();
visitor.setSqlEndOfComment(endOfComment);
if (statementList.size() > 0) {
// A MySqlHintStatement at index 0, and hints that follow it, are not visited.
// NOTE(review): lastIsHint is never reset to false. Once statement 0 is a hint, every
// later hint statement is also skipped, even after a non-hint statement. Confirm
// whether only the leading run of hints was meant to be skipped.
boolean lastIsHint = false;
for (int i = 0; i < statementList.size(); i++) {
SQLStatement stmt = statementList.get(i);
if ((i == 0 || lastIsHint) && stmt instanceof MySqlHintStatement) {
lastIsHint = true;
continue;
}
// A ParserException raised while visiting is a violation regardless of strict mode.
try {
stmt.accept(visitor);
} catch (ParserException e) {
violations.add(new SyntaxErrorViolation(e, sql));
}
}
}
if (visitor.getViolations().size() > 0) {
violations.addAll(visitor.getViolations());
}
Map<String, WallSqlTableStat> tableStat = context.getTableStats();
// updateCheckHandlerEnable becomes true when an update-check handler is configured and
// some UPDATE targets a table with configured update-check columns. In that case the
// SQL is kept out of the black/white lists and a standalone WallSqlStat is built.
boolean updateCheckHandlerEnable = false;
{
WallUpdateCheckHandler updateCheckHandler = config.getUpdateCheckHandler();
if (updateCheckHandler != null) {
for (SQLStatement stmt : statementList) {
if (stmt instanceof SQLUpdateStatement) {
SQLUpdateStatement updateStmt = (SQLUpdateStatement) stmt;
SQLName table = updateStmt.getTableName();
if (table != null) {
String tableName = table.getSimpleName();
Set<String> updateCheckColumns = config.getUpdateCheckTable(tableName);
if (updateCheckColumns != null && updateCheckColumns.size() > 0) {
updateCheckHandlerEnable = true;
break;
}
}
}
}
}
}
WallSqlStat sqlStat = null;
// SQL at or above MAX_SQL_LENGTH is never added to the black or white list.
if (violations.size() > 0) {
violationCount.incrementAndGet();
if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
sqlStat = addBlackSql(sql, tableStat, context.getFunctionStats(), violations, syntaxError);
}
} else {
if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
// With a select limit configured, SQL containing a SELECT statement is not
// white-listed. Presumably this is because the visitor may rewrite it to apply
// the limit; confirm.
boolean selectLimit = false;
if (config.getSelectLimit() > 0) {
for (SQLStatement stmt : statementList) {
if (stmt instanceof SQLSelectStatement) {
selectLimit = true;
break;
}
}
}
if (!selectLimit) {
sqlStat = addWhiteSql(sql, tableStat, context.getFunctionStats(), syntaxError);
}
}
}
if (sqlStat == null && updateCheckHandlerEnable) {
sqlStat = new WallSqlStat(tableStat, context.getFunctionStats(), violations, syntaxError);
}
Map<String, WallSqlTableStat> tableStats = null;
Map<String, WallSqlFunctionStat> functionStats = null;
if (context != null) {
tableStats = context.getTableStats();
functionStats = context.getFunctionStats();
recordStats(tableStats, functionStats);
}
// Prefer the list/stat-backed result; otherwise build one from the collected violations.
WallCheckResult result;
if (sqlStat != null) {
context.setSqlStat(sqlStat);
result = new WallCheckResult(sqlStat, statementList);
} else {
result = new WallCheckResult(null, violations, tableStats, functionStats, statementList, syntaxError);
}
// If the visitor changed the AST, report the regenerated SQL instead of the input text.
String resultSql;
if (visitor.isSqlModified()) {
resultSql = SQLUtils.toSQLString(statementList, dbType);
} else {
resultSql = sql;
}
result.setSql(resultSql);
result.setUpdateCheckItems(visitor.getUpdateCheckItems());
return result;
}
Usage of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in the Alibaba druid project.
Class SQLASTVisitorAdapterTest, method test_adapter:
/**
 * Smoke test for {@link SQLASTVisitorAdapter}. A freshly constructed (no-arg) instance of
 * each listed AST node type must accept a plain adapter without an exception escaping.
 */
public void test_adapter() throws Exception {
    final SQLASTVisitorAdapter visitor = new SQLASTVisitorAdapter();

    new SQLBinaryOpExpr().accept(visitor);
    new SQLInListExpr().accept(visitor);
    new SQLSelectQueryBlock().accept(visitor);
    new SQLDropTableStatement().accept(visitor);
    new SQLCreateTableStatement().accept(visitor);
    new SQLDeleteStatement().accept(visitor);
    new SQLCurrentOfCursorExpr().accept(visitor);
    new SQLInsertStatement().accept(visitor);
    new SQLUpdateStatement().accept(visitor);
    new SQLNotNullConstraint().accept(visitor);
    new SQLMethodInvokeExpr().accept(visitor);
    new SQLCallStatement().accept(visitor);

    // Quantified-comparison and DEFAULT expressions.
    new SQLSomeExpr().accept(visitor);
    new SQLAnyExpr().accept(visitor);
    new SQLAllExpr().accept(visitor);
    new SQLDefaultExpr().accept(visitor);

    new SQLCommentStatement().accept(visitor);
    new SQLDropViewStatement().accept(visitor);
    new SQLSavePointStatement().accept(visitor);
    new SQLReleaseSavePointStatement().accept(visitor);
    new SQLCreateDatabaseStatement().accept(visitor);
    new SQLAlterTableDropIndex().accept(visitor);
    new SQLOver().accept(visitor);
    new SQLWithSubqueryClause().accept(visitor);

    // ALTER TABLE family and constraint / hint nodes.
    new SQLAlterTableAlterColumn().accept(visitor);
    new SQLAlterTableStatement().accept(visitor);
    new SQLAlterTableDisableConstraint().accept(visitor);
    new SQLAlterTableEnableConstraint().accept(visitor);
    new SQLColumnCheck().accept(visitor);
    new SQLExprHint().accept(visitor);
    new SQLAlterTableDropConstraint().accept(visitor);
}
Usage of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in the Alibaba druid project.
Class MySqlUpdateTest_15, method test_0:
/**
 * Parses an UPDATE whose WHERE clause contains an IN-subquery. It then checks the schema
 * statistics, the upper- and lower-case formatted output, wall validation, and the
 * string form of the WHERE expression on its own.
 */
public void test_0() throws Exception {
    String sql = "update students set name='test' where id in (select stu_id from score where s <100)";

    MySqlStatementParser parser = new MySqlStatementParser(sql);
    List<SQLStatement> statementList = parser.parseStatementList();
    print(statementList);
    // Check the count before indexing, so an unexpected parse result fails with a clear
    // assertion instead of an IndexOutOfBoundsException from get(0).
    assertEquals(1, statementList.size());
    SQLStatement stmt = statementList.get(0);

    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
    stmt.accept(visitor);

    // Both the UPDATE target and the table read by the IN-subquery are recorded.
    assertEquals(2, visitor.getTables().size());
    assertEquals(4, visitor.getColumns().size());
    assertTrue(visitor.containsTable("students"));
    assertTrue(visitor.containsTable("score"));
    assertTrue(visitor.getColumns().contains(new Column("students", "name")));
    assertTrue(visitor.getColumns().contains(new Column("score", "stu_id")));

    // Default formatting: upper-case keywords, subquery indented two tabs.
    {
        String output = SQLUtils.toMySqlString(stmt);
        assertEquals("UPDATE students\n"
                + "SET name = 'test'\n"
                + "WHERE id IN (\n"
                + "\t\tSELECT stu_id\n"
                + "\t\tFROM score\n"
                + "\t\tWHERE s < 100\n"
                + "\t)", output);
    }
    // Lower-case keyword formatting; layout is otherwise identical.
    {
        String output = SQLUtils.toMySqlString(stmt, SQLUtils.DEFAULT_LCASE_FORMAT_OPTION);
        assertEquals("update students\n"
                + "set name = 'test'\n"
                + "where id in (\n"
                + "\t\tselect stu_id\n"
                + "\t\tfrom score\n"
                + "\t\twhere s < 100\n"
                + "\t)", output);
    }

    assertTrue(WallUtils.isValidateMySql(sql));

    // Assert the type first so a wrong statement kind fails clearly, not via ClassCastException.
    assertTrue(stmt instanceof SQLUpdateStatement);
    SQLUpdateStatement update = (SQLUpdateStatement) stmt;
    SQLExpr where = update.getWhere();
    // Printed on its own, the WHERE expression indents the subquery one tab less.
    assertEquals("id IN (\n"
            + "\tSELECT stu_id\n"
            + "\tFROM score\n"
            + "\tWHERE s < 100\n"
            + ")", where.toString());
}
Usage of com.alibaba.druid.sql.ast.statement.SQLUpdateStatement in the Alibaba druid project.
Class MySqlUpdateTest_17, method test_0:
/**
 * Parses a single-table UPDATE that uses function calls in its SET clause and
 * {@code ?} placeholders. It then checks the schema statistics, the upper- and
 * lower-case string forms, wall validation, and the WHERE expression.
 */
public void test_0() throws Exception {
    String sql = "update security_group_ip_count set ip_count=GREATEST(ip_count-?, 0), gmt_modified=now() where group_id=? ";

    MySqlStatementParser parser = new MySqlStatementParser(sql);
    List<SQLStatement> statementList = parser.parseStatementList();
    print(statementList);
    // Check the count before indexing, so an unexpected parse result fails with a clear
    // assertion instead of an IndexOutOfBoundsException from get(0).
    assertEquals(1, statementList.size());
    SQLStatement stmt = statementList.get(0);

    MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
    stmt.accept(visitor);

    // One table; presumably the three columns are ip_count, gmt_modified and group_id.
    assertEquals(1, visitor.getTables().size());
    assertEquals(3, visitor.getColumns().size());
    assertTrue(visitor.containsTable("security_group_ip_count"));
    assertTrue(visitor.getColumns().contains(new Column("security_group_ip_count", "ip_count")));

    // toString() upper-cases keywords and toLowerCaseString() lower-cases them. Function
    // names keep the case they were written in (GREATEST, now()).
    assertEquals("UPDATE security_group_ip_count\n"
            + "SET ip_count = GREATEST(ip_count - ?, 0), gmt_modified = now()\n"
            + "WHERE group_id = ?", stmt.toString());
    assertEquals("update security_group_ip_count\n"
            + "set ip_count = GREATEST(ip_count - ?, 0), gmt_modified = now()\n"
            + "where group_id = ?", stmt.toLowerCaseString());

    assertTrue(WallUtils.isValidateMySql(sql));

    // Assert the type first so a wrong statement kind fails clearly, not via ClassCastException.
    assertTrue(stmt instanceof SQLUpdateStatement);
    SQLUpdateStatement update = (SQLUpdateStatement) stmt;
    SQLExpr where = update.getWhere();
    assertEquals("group_id = ?", where.toString());
}
Aggregations