Package com.alibaba.druid.wall

Source Code of com.alibaba.druid.wall.WallProvider

/*
* Copyright 1999-2011 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.druid.wall;

import java.security.PrivilegedAction;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.parser.NotAllowCommentException;
import com.alibaba.druid.sql.parser.ParserException;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.sql.parser.Token;
import com.alibaba.druid.sql.visitor.ExportParameterVisitor;
import com.alibaba.druid.util.LRUCache;
import com.alibaba.druid.wall.spi.WallVisitorUtils;
import com.alibaba.druid.wall.violation.IllegalSQLObjectViolation;
import com.alibaba.druid.wall.violation.SyntaxErrorViolation;

public abstract class WallProvider {

    private LinkedHashMap<String, WallSqlStat>        whiteList;

    private int                                       whileListMaxSize  = 1024;

    private int                                       whiteSqlMaxLength = 1024;                                         // 1k

    protected final WallConfig                        config;

    private final ReentrantReadWriteLock              lock              = new ReentrantReadWriteLock();

    private static final ThreadLocal<Boolean>         privileged        = new ThreadLocal<Boolean>();

    private final ConcurrentMap<String, WallDenyStat> deniedTables      = new ConcurrentHashMap<String, WallDenyStat>();
    private final ConcurrentMap<String, WallDenyStat> deniedFunctions   = new ConcurrentHashMap<String, WallDenyStat>();
    private final ConcurrentMap<String, WallDenyStat> deniedSchemas     = new ConcurrentHashMap<String, WallDenyStat>();

    public final WallDenyStat                         commentDeniedStat = new WallDenyStat();

    protected String                                  dbType            = null;

    public WallProvider(WallConfig config){
        this.config = config;
    }

    public WallProvider(WallConfig config, String dbType){
        this.config = config;
        this.dbType = dbType;
    }

    public WallConfig getConfig() {
        return config;
    }

    public void addWhiteSql(String sql) {
        lock.writeLock().lock();
        try {
            if (whiteList == null) {
                whiteList = new LRUCache<String, WallSqlStat>(whileListMaxSize);
            }

            whiteList.put(sql, new WallSqlStat());
        } finally {
            lock.writeLock().unlock();
        }
    }

    public Set<String> getWhiteList() {
        Set<String> hashSet = new HashSet<String>();
        lock.readLock().lock();
        try {
            if (whiteList != null) {
                hashSet.addAll(whiteList.keySet());
            }
        } finally {
            lock.readLock().unlock();
        }

        return hashSet;
    }

    public void clearCache() {
        clearWhiteList();
    }

    public void clearWhiteList() {
        lock.writeLock().lock();
        try {
            if (whiteList != null) {
                whiteList = null;
            }
        } finally {
            lock.writeLock().unlock();
        }
    }

    public WallSqlStat getWhiteSql(String sql) {
        lock.readLock().lock();
        try {
            if (whiteList == null) {
                return null;
            }

            return whiteList.get(sql);
        } finally {
            lock.readLock().unlock();
        }
    }

    public boolean whiteContains(String sql) {
        return getWhiteSql(sql) != null;
    }

    public abstract SQLStatementParser createParser(String sql);

    public abstract WallVisitor createWallVisitor();

    public abstract ExportParameterVisitor createExportParameterVisitor();

    public boolean checkValid(String sql) {
        WallContext originalContext = WallContext.current();

        try {
            WallContext.createIfNotExists(dbType);
            WallCheckResult result = check(sql);
            return result.getViolations().isEmpty();
        } finally {
            if (originalContext == null) {
                WallContext.clearContext();
            }
        }
    }

    public void incrementCommentDeniedCount() {
        this.commentDeniedStat.incrementAndGetDenyCount();
    }

    public boolean checkDenyFunction(String functionName) {
        if (functionName == null) {
            return true;
        }

        functionName = functionName.toLowerCase();
        if (getConfig().getDenyFunctions().contains(functionName)) {
            WallDenyStat denyStat = this.deniedFunctions.get(functionName);
            if (denyStat == null) {
                this.deniedFunctions.putIfAbsent(functionName, new WallDenyStat());
                denyStat = this.deniedFunctions.get(functionName);
            }
            denyStat.incrementAndGetDenyCount();
            return false;
        }

        return true;
    }

    public boolean checkDenySchema(String schemaName) {
        if (schemaName == null) {
            return true;
        }

        schemaName = schemaName.toLowerCase();
        if (getConfig().getDenySchemas().contains(schemaName)) {
            WallDenyStat denyStat = this.deniedSchemas.get(schemaName);
            if (denyStat == null) {
                this.deniedSchemas.putIfAbsent(schemaName, new WallDenyStat());
                denyStat = this.deniedSchemas.get(schemaName);
            }
            denyStat.incrementAndGetDenyCount();
            return false;
        }

        return true;
    }

    public boolean checkDenyTable(String tableName) {
        if (tableName == null) {
            return true;
        }

        tableName = WallVisitorUtils.form(tableName);
        if (getConfig().getDenyTables().contains(tableName)) {
            WallDenyStat denyStat = this.deniedTables.get(tableName);
            if (denyStat == null) {
                this.deniedTables.putIfAbsent(tableName, new WallDenyStat());
                denyStat = this.deniedTables.get(tableName);
            }
            denyStat.incrementAndGetDenyCount();
            return false;
        }

        return true;
    }

    public boolean checkReadOnlyTable(String tableName) {
        if (tableName == null) {
            return true;
        }

        tableName = WallVisitorUtils.form(tableName);
        if (getConfig().isReadOnly(tableName)) {
            WallDenyStat denyStat = this.deniedTables.get(tableName);
            if (denyStat == null) {
                this.deniedTables.putIfAbsent(tableName, new WallDenyStat());
                denyStat = this.deniedTables.get(tableName);
            }
            denyStat.incrementAndGetDenyCount();
            return false;
        }

        return true;
    }

    public WallDenyStat getDenniedTableStat(String tableName) {
        if (tableName == null) {
            return null;
        }

        String formName = WallVisitorUtils.form(tableName);
        return deniedTables.get(formName);
    }

    public WallDenyStat getDenniedSchemaStat(String schemaName) {
        if (schemaName == null) {
            return null;
        }

        String formName = WallVisitorUtils.form(schemaName);
        return deniedSchemas.get(formName);
    }

    public WallDenyStat getCommentDenyStat() {
        return this.commentDeniedStat;
    }

    public WallCheckResult check(String sql) {
        WallCheckResult result = new WallCheckResult();

        if (config.isDoPrivilegedAllow() && ispPivileged()) {
            return result;
        }

        // first step, check whiteList
        WallSqlStat sqlStat = getWhiteSql(sql);
        if (sqlStat != null) {
            sqlStat.incrementAndGetExecuteCount();
            return result;
        }

        SQLStatementParser parser = createParser(sql);

        if (!config.isCommentAllow()) {
            parser.getLexer().setAllowComment(false); // deny comment
        }

        try {
            parser.parseStatementList(result.getStatementList());
        } catch (NotAllowCommentException e) {
            result.getViolations().add(new SyntaxErrorViolation(e, sql));
            incrementCommentDeniedCount();
            return result;
        } catch (ParserException e) {
            if (config.isStrictSyntaxCheck()) {
                result.getViolations().add(new SyntaxErrorViolation(e, sql));
            }
            return result;
        } catch (Exception e) {
            result.getViolations().add(new SyntaxErrorViolation(e, sql));
            return result;
        }

        if (parser.getLexer().token() != Token.EOF) {
            result.getViolations().add(new IllegalSQLObjectViolation(sql));
            return result;
        }

        if (result.getStatementList().size() == 0) {
            return result;
        }

        if (result.getStatementList().size() > 1 && !config.isMultiStatementAllow()) {
            result.getViolations().add(new IllegalSQLObjectViolation(sql));
            return result;
        }

        SQLStatement stmt = result.getStatementList().get(0);

        WallVisitor visitor = createWallVisitor();

        try {
            stmt.accept(visitor);
        } catch (ParserException e) {
            result.getViolations().add(new IllegalSQLObjectViolation(sql));
            return result;
        }

        if (visitor.getViolations().size() == 0) {
            if (sql.length() < whiteSqlMaxLength) {
                this.addWhiteSql(sql);
            }
        }

        result.getViolations().addAll(visitor.getViolations());
        return result;
    }

    public static boolean ispPivileged() {
        Boolean value = privileged.get();
        if (value == null) {
            return false;
        }

        return value.booleanValue();
    }

    public static <T> T doPrivileged(PrivilegedAction<T> action) {
        privileged.set(Boolean.TRUE);
        try {
            return action.run();
        } finally {
            privileged.set(null);
        }
    }

    private static final ThreadLocal<Object> tenantValueLocal = new ThreadLocal<Object>();

    public static void setTenantValue(Object value) {
        tenantValueLocal.set(value);
    }

    public static Object getTenantValue() {
        return tenantValueLocal.get();
    }
}
TOP

Related Classes of com.alibaba.druid.wall.WallProvider

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.