/*******************************************************************************
* Copyright 2013 butor.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package org.butor.dao;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.sql.DataSource;
import org.butor.dao.extractor.MaxRowsResultSetExtractor;
import org.butor.utils.ApplicationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowCallbackHandler;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.BeanPropertySqlParameterSource;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;
import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator;
import org.springframework.jdbc.support.SQLExceptionTranslator;
import org.springframework.util.CollectionUtils;
import com.google.common.base.Strings;
import com.google.common.base.Throwables;
/**
 * Base class for JDBC data-access objects built on Spring's {@link NamedParameterJdbcOperations}.
 *
 * <p>Every query/update helper in this class:
 * <ul>
 * <li>replaces the {@code __authDataSql__} placeholder in the SQL with the configured authorization SQL,</li>
 * <li>builds named parameters from the supplied arguments (maps, {@link MapSqlParameterSource}s or beans),</li>
 * <li>converts failures into {@link ApplicationException}s via {@link #translateException}, and</li>
 * <li>reports each call (name, SQL, parameters, outcome, elapsed time, result) to the configured {@link SqlLogger}.</li>
 * </ul>
 * Subclasses that also implement {@link DaoWithHistory} get a best-effort history row inserted before each
 * {@link #update} and {@link #delete}.
 */
public abstract class AbstractDao {
    /** Call name used when inserting a history row; calls with this name are never historized themselves. */
    private static final String PROC_NAME_HISTORIZE = "__insertHistory";

    /** Placeholder in SQL text that is substituted with {@link #authDataSql}. */
    private static final String AUTH_DATA_SQL_PLACEHOLDER = "__authDataSql__";

    /** Matches up to the opening parenthesis of the column list of an {@code INSERT ... (} statement. */
    private static final Pattern INSERT_COLUMNS_START = Pattern.compile("(?i)insert[^\\(]*\\(");

    /** Matches up to the opening parenthesis of the {@code VALUES (} list. */
    private static final Pattern VALUES_START = Pattern.compile("(?i)values[^\\(]*\\(");

    private static final SQLExceptionTranslator DEFAULT_DAO_EXCEPTION_TRANSLATOR = new SQLErrorCodeSQLExceptionTranslator();

    /** Lazily created from {@link #dataSource}; reset whenever the data source changes. */
    private NamedParameterJdbcOperations namedParameterJdbcTemplate = null;
    private DataSource dataSource = null;
    private String authDataSql = null;
    protected SQLExceptionTranslator daoExceptionTranslator = DEFAULT_DAO_EXCEPTION_TRANSLATOR;
    private SqlLogger sqlLogger = new DefaultSqlLogger();
    protected Logger logger = LoggerFactory.getLogger(getClass());
    /** The first parameter-override warning is always logged; later ones only when DEBUG is enabled. */
    private boolean warnParamOverride = true;
    /** History INSERT derived from {@link DaoWithHistory#getInsertSql()}; built on first use. */
    private String insertHistorySql = null;

    /**
     * Template for one JDBC call: resolves the call name, prepares the parameters, runs {@link #doCall},
     * converts failures through {@link #translateException} and reports the outcome to the SQL logger.
     *
     * @param <T> result type of the call
     */
    private abstract class CallTemplate<T> {
        /** Executes {@code sql} with the prepared parameters and returns its result. */
        abstract T doCall(String sql, MapSqlParameterSource params);

        /**
         * Runs the call.
         *
         * @param procName name reported to the SQL logger; derived from the call stack when {@code null}
         * @param sql the SQL to execute
         * @param args either a single ready-made {@link MapSqlParameterSource}, or arguments for {@link #prepParams}
         * @return the value produced by {@link #doCall}, or {@code null} if {@link #translateException} returns
         *         instead of throwing
         */
        T call(String procName, String sql, Object... args) {
            long start = System.currentTimeMillis();
            if (procName == null) {
                procName = guessSqlProcName();
            }
            boolean success = false;
            T result = null;
            MapSqlParameterSource params = null;
            try {
                // A lone MapSqlParameterSource is used as-is; going through prepParams would copy only its
                // values and lose any SQL types registered on it.
                if (args.length == 1 && args[0] instanceof MapSqlParameterSource) {
                    params = (MapSqlParameterSource) args[0];
                } else {
                    params = prepParams(args);
                }
                result = doCall(sql, params);
                success = true;
                return result;
            } catch (Exception e) {
                // The default translateException always throws; the return below only matters if a
                // subclass overrides it to return normally.
                translateException("call", sql, e);
                return null;
            } finally {
                long elapsed = System.currentTimeMillis() - start;
                if (sqlLogger != null) {
                    sqlLogger.logQuery(procName, sql, params, success, elapsed, result);
                }
            }
        }
    }

    /**
     * Derives a call name from the current thread's stack: the first frame whose class name starts with this
     * DAO's concrete class name, formatted as {@code className.methodName}.
     *
     * @return the derived name, or {@code "?sqlProcName?"} when no matching frame is found
     */
    private String guessSqlProcName() {
        String daoCallingClsName = this.getClass().getName();
        for (StackTraceElement ste : Thread.currentThread().getStackTrace()) {
            if (ste.getClassName().startsWith(daoCallingClsName)) {
                return ste.getClassName() + "." + ste.getMethodName();
            }
        }
        return "?sqlProcName?";
    }

    /**
     * Replaces every {@code __authDataSql__} placeholder in {@code sql} with the configured authorization SQL.
     *
     * @param sql the SQL text; must not be {@code null}
     * @return {@code sql} with all placeholders substituted, or {@code sql} itself when it has none
     * @throws IllegalStateException if {@code sql} contains the placeholder but no authorization SQL was set
     */
    private String checkAuthPlaceholder(String sql) {
        if (!sql.contains(AUTH_DATA_SQL_PLACEHOLDER)) {
            return sql;
        }
        if (authDataSql == null) {
            throw new IllegalStateException("SQL contains " + AUTH_DATA_SQL_PLACEHOLDER
                    + " but no authDataSql has been configured");
        }
        // String.replace substitutes every occurrence in a single pass. Repeating it until no placeholder
        // remains would never terminate if authDataSql itself contained the placeholder.
        return sql.replace(AUTH_DATA_SQL_PLACEHOLDER, authDataSql);
    }

    /**
     * Runs a query expected to return a single string value.
     *
     * @param procName call name for logging ({@code null} to derive it from the stack)
     * @param sql SQL with named parameters
     * @param args parameter sources, see {@link #prepParams}
     * @return the single value selected
     */
    protected String queryForString(String procName, String sql, Object... args) throws DataAccessException {
        sql = checkAuthPlaceholder(sql);
        return new CallTemplate<String>() {
            @Override
            String doCall(String sql, MapSqlParameterSource params) {
                return getJdbcTemplate().queryForObject(sql, params, String.class);
            }
        }.call(procName, sql, args);
    }

    /**
     * Runs a query expected to return a single integer value.
     *
     * @param procName call name for logging ({@code null} to derive it from the stack)
     * @param sql SQL with named parameters
     * @param args parameter sources, see {@link #prepParams}
     * @return the single value selected, or {@code 0} when it is SQL {@code NULL}
     */
    protected int queryForInt(String procName, String sql, Object... args) throws DataAccessException {
        sql = checkAuthPlaceholder(sql);
        return new CallTemplate<Integer>() {
            @Override
            Integer doCall(String sql, MapSqlParameterSource params) {
                // Same contract as the deprecated NamedParameterJdbcOperations.queryForInt: NULL maps to 0.
                Number value = getJdbcTemplate().queryForObject(sql, params, Integer.class);
                return value != null ? value.intValue() : 0;
            }
        }.call(procName, sql, args);
    }

    /**
     * Runs a query and returns its first row mapped to {@code resultSetClass} with the default extractor.
     *
     * @return the first mapped row, or {@code null} when there is none
     */
    protected <T> T query(String procName, String sql, Class<T> resultSetClass, Object... args) throws DataAccessException {
        return this.query(procName, sql, resultSetClass, null, args);
    }

    /**
     * Runs a query and returns the first element of the list produced by the extractor.
     *
     * @param extractor extractor to use; when {@code null} the default max-rows extractor for
     *        {@code resultSetClass} is used. NOTE(review): its result is cast to {@code List<T>}, so it is
     *        expected to produce a list despite its declared type.
     * @return the first extracted element, or {@code null} when the extracted list is null or empty
     */
    protected <T> T query(String procName, String sql, final Class<T> resultSetClass, final ResultSetExtractor<T> extractor,
            Object... args) throws DataAccessException {
        sql = checkAuthPlaceholder(sql);
        return new CallTemplate<T>() {
            @Override
            @SuppressWarnings("unchecked")
            T doCall(String sql, MapSqlParameterSource params) {
                ResultSetExtractor<T> ext = extractor != null ? extractor : getDefaultResultSetExtractor(resultSetClass);
                List<T> rows = (List<T>) getJdbcTemplate().query(sql, params, ext);
                return CollectionUtils.isEmpty(rows) ? null : rows.get(0);
            }
        }.call(procName, sql, args);
    }

    /**
     * Streams the rows of a query to {@code handler} one at a time, mapping each row to {@code resultSetClass}
     * with a {@link BeanPropertyRowMapper}.
     *
     * <p>Unlike the other helpers, this uses a fresh {@link NamedParameterJdbcTemplate} whose statements are
     * created by {@link StreamingStatementCreator}.
     */
    protected <T> void queryList(String procName, String sql, final Class<T> resultSetClass, final RowHandler<T> handler,
            Object... args) {
        sql = checkAuthPlaceholder(sql);
        new CallTemplate<Integer>() {
            @Override
            Integer doCall(String sql, MapSqlParameterSource params) {
                final RowMapper<T> mapper = new BeanPropertyRowMapper<T>(resultSetClass);
                NamedParameterJdbcTemplate template = new NamedParameterJdbcTemplate(dataSource) {
                    @Override
                    protected PreparedStatementCreator getPreparedStatementCreator(String sql,
                            SqlParameterSource paramSource) {
                        PreparedStatementCreator originalCreator = super.getPreparedStatementCreator(sql, paramSource);
                        return new StreamingStatementCreator(originalCreator);
                    }
                };
                final AtomicInteger rowNum = new AtomicInteger(0);
                template.query(sql, params, new RowCallbackHandler() {
                    @Override
                    public void processRow(ResultSet rs) throws SQLException {
                        // Row numbers passed to the mapper start at 1.
                        handler.handleRow(mapper.mapRow(rs, rowNum.incrementAndGet()));
                    }
                });
                // The row count is the logged result of the call.
                return Integer.valueOf(rowNum.get());
            }
        }.call(procName, sql, args);
    }

    /**
     * Runs a query with the default extractor and returns its first row.
     *
     * @return the first mapped row, or {@code null} when there is none
     */
    protected <T> T queryFirst(String procName, String sql, Class<T> resultSetClass, Object... args) throws DataAccessException {
        List<T> list = this.queryList(procName, sql, resultSetClass, args);
        return list == null || list.isEmpty() ? null : list.get(0);
    }

    /**
     * Runs a query and maps its rows to {@code resultSetClass} with the default extractor.
     *
     * @return the mapped rows; never {@code null}
     */
    protected <T> List<T> queryList(String procName, String sql, Class<T> resultSetClass, Object... args) throws DataAccessException {
        return this.queryList(procName, sql, resultSetClass, (ResultSetExtractor<T>) null, args);
    }

    /**
     * Runs a query and returns the list produced by the extractor.
     *
     * @param extractor extractor to use; when {@code null} the default max-rows extractor for
     *        {@code resultSetClass} is used. NOTE(review): its result is cast to {@code List<T>}.
     * @return the extracted list, or an empty immutable list when it is null or empty; never {@code null}
     */
    protected <T> List<T> queryList(String procName, String sql, final Class<T> resultSetClass,
            final ResultSetExtractor<T> extractor, Object... args) throws DataAccessException {
        sql = checkAuthPlaceholder(sql);
        return new CallTemplate<List<T>>() {
            @Override
            @SuppressWarnings("unchecked")
            List<T> doCall(String sql, MapSqlParameterSource params) {
                ResultSetExtractor<T> ext = extractor != null ? extractor : getDefaultResultSetExtractor(resultSetClass);
                List<T> rows = (List<T>) getJdbcTemplate().query(sql, params, ext);
                return CollectionUtils.isEmpty(rows) ? Collections.<T>emptyList() : rows;
            }
        }.call(procName, sql, args);
    }

    /**
     * Executes an UPDATE, first recording a history row when this DAO implements {@link DaoWithHistory}.
     *
     * @see NamedParameterJdbcOperations#update(String, SqlParameterSource, KeyHolder)
     * @param procName call name for logging and history ({@code null} to derive it from the stack)
     * @param sql SQL with named parameters
     * @param args parameter sources, see {@link #prepParams}
     * @return the affected row count and generated key, if any
     * @throws DataAccessException
     */
    protected UpdateResult update(String procName, String sql, Object... args) throws DataAccessException {
        checkForHistory("UPDATE", procName, args);
        return updateInternal(procName, sql, args);
    }

    /** Executes a data-modifying statement, capturing generated keys. */
    private UpdateResult updateInternal(String procName, String sql, Object... args) throws DataAccessException {
        sql = checkAuthPlaceholder(sql);
        return new CallTemplate<UpdateResult>() {
            @Override
            UpdateResult doCall(String sql, MapSqlParameterSource params) {
                KeyHolder kh = new GeneratedKeyHolder();
                int rowsAffected = getJdbcTemplate().update(sql, params, kh);
                return new UpdateResult(kh, rowsAffected);
            }
        }.call(procName, sql, args);
    }

    /** Executes an INSERT (never historized) and returns the affected row count and generated key, if any. */
    protected UpdateResult insert(String procName, String sql, Object... args) throws DataAccessException {
        return updateInternal(procName, sql, args);
    }

    /** Executes a DELETE, first recording a history row when this DAO implements {@link DaoWithHistory}. */
    protected UpdateResult delete(String procName, String sql, Object... args) throws DataAccessException {
        checkForHistory("DELETE", procName, args);
        return updateInternal(procName, sql, args);
    }

    /** Returns the extractor used when callers do not supply one. */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private <T> ResultSetExtractor<T> getDefaultResultSetExtractor(Class<T> resultSetClass) {
        return new MaxRowsResultSetExtractor(resultSetClass);
    }

    /** Returns the shared template for {@link #dataSource}, creating it on first use. */
    private NamedParameterJdbcOperations getJdbcTemplate() {
        if (namedParameterJdbcTemplate == null) {
            namedParameterJdbcTemplate = new NamedParameterJdbcTemplate(dataSource);
        }
        return namedParameterJdbcTemplate;
    }

    /**
     * Converts a failure into an {@link ApplicationException} and throws it. This implementation never
     * returns normally.
     *
     * <ul>
     * <li>{@link SQLException}: translated by {@link #daoExceptionTranslator}, then wrapped.</li>
     * <li>{@link DataAccessException}: its cause (or itself when it has none) is wrapped.</li>
     * <li>anything else: wrapped with a message naming this DAO, the task and the SQL.</li>
     * </ul>
     */
    protected DataAccessException translateException(String task_, String sql, Exception ex_) {
        if (ex_ instanceof SQLException) {
            DataAccessException cause = daoExceptionTranslator.translate(this.getClass().getName(), sql,
                    (SQLException) ex_);
            throw ApplicationException.exception(cause, DAOMessageID.SQL_EXCEPTION.getMessage());
        } else if (ex_ instanceof DataAccessException) {
            Throwable cause = ex_.getCause();
            if (cause != null) {
                throw ApplicationException.exception(cause, DAOMessageID.SQL_EXCEPTION.getMessage(cause.getMessage()));
            } else {
                throw ApplicationException.exception(ex_, DAOMessageID.SQL_EXCEPTION.getMessage(ex_.getMessage()));
            }
        }
        throw ApplicationException.exception(
                String.format("Failed to execute %s.%s with SQL=%s", this.getClass().getName(), task_, sql), ex_,
                DAOMessageID.DAO_FAILURE.getMessage());
    }

    /**
     * Merges the given arguments into one set of named parameters; later arguments override earlier ones.
     *
     * <ul>
     * <li>{@code null} arguments are ignored;</li>
     * <li>{@link Map}s contribute their entries;</li>
     * <li>{@link MapSqlParameterSource}s contribute their values (registered SQL types are not copied);</li>
     * <li>any other object is treated as a bean: each instance field, including inherited ones, becomes a
     * parameter read through the bean's getter. Enum fields are bound as {@code VARCHAR}.</li>
     * </ul>
     *
     * @throws RuntimeException if a bean field cannot be read as a bean property
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    protected MapSqlParameterSource prepParams(Object... args) {
        MapSqlParameterSource msp = new MapSqlParameterSource();
        for (Object arg : args) {
            if (arg == null) {
                continue;
            }
            if (arg instanceof Map) {
                msp.addValues((Map) arg);
            } else if (arg instanceof MapSqlParameterSource) {
                msp.addValues(((MapSqlParameterSource) arg).getValues());
            } else {
                addBeanParams(msp, arg);
            }
        }
        return msp;
    }

    /** Adds one parameter per instance field of {@code bean}, walking from its class up to (excluding) Object. */
    private void addBeanParams(MapSqlParameterSource msp, Object bean) {
        BeanPropertySqlParameterSource sps = new BeanPropertySqlParameterSource(bean);
        List<Field> fields = new ArrayList<Field>();
        for (Class<?> c = bean.getClass(); c != null && !c.equals(Object.class); c = c.getSuperclass()) {
            fields.addAll(Arrays.asList(c.getDeclaredFields()));
        }
        for (Field f : fields) {
            // Static fields (e.g. serialVersionUID) and synthetic fields (e.g. an inner class's outer
            // reference) are not bean properties; reading them through the bean wrapper would fail.
            if (Modifier.isStatic(f.getModifiers()) || f.isSynthetic()) {
                continue;
            }
            try {
                addBeanParam(msp, sps, f);
            } catch (Exception e) {
                throw Throwables.propagate(e);
            }
        }
    }

    /** Adds the bean property backing {@code f} to {@code msp}, warning when it overrides an existing value. */
    private void addBeanParam(MapSqlParameterSource msp, BeanPropertySqlParameterSource sps, Field f) {
        String fn = f.getName();
        if (msp.hasValue(fn)) {
            if (warnParamOverride) {
                warnParamOverride = false;
                logger.warn("Field with name={} has been already mapped by another arg bean. Overriding! "
                        + "Next time will warn if DEBUG is enabled.", fn);
            } else if (logger.isDebugEnabled()) {
                logger.warn("Field with name={} has been already mapped by another arg bean. Overriding!", fn);
            }
        }
        if (Enum.class.isAssignableFrom(f.getType())) {
            sps.registerSqlType(fn, Types.VARCHAR);
        }
        Object value = sps.getValue(fn);
        msp.addValue(fn, value, sps.getSqlType(fn), sps.getTypeName(fn));
        if (logger.isDebugEnabled()) {
            logger.debug("prepared sql arg: name={}, value={}, type={}", fn, value, sps.getTypeName(fn));
        }
    }

    public void setDaoExceptionMapper(SQLExceptionTranslator daoExceptionTranslator_) {
        daoExceptionTranslator = daoExceptionTranslator_;
    }

    /** Sets the data source; any template built for a previous data source is discarded. */
    public void setDataSource(DataSource dataSource_) {
        dataSource = dataSource_;
        namedParameterJdbcTemplate = null;
    }

    /** Sets the SQL substituted for the {@code __authDataSql__} placeholder. */
    public void setAuthDataSql(String authDataSql_) {
        authDataSql = authDataSql_;
    }

    /** Outcome of a data-modifying statement. */
    protected static class UpdateResult {
        /**
         * First value of the single generated-key row, as a long; {@code null} when there is not exactly one
         * generated-key row, or its first value is missing or not numeric.
         */
        public final Long key;
        /** Row count reported by the driver. */
        public final int numberOfRowAffected;

        public UpdateResult(KeyHolder keyHolder, int numberOfRowAffected) {
            this.numberOfRowAffected = numberOfRowAffected;
            this.key = extractKey(keyHolder);
        }

        private static Long extractKey(KeyHolder keyHolder) {
            List<Map<String, Object>> keys = keyHolder.getKeyList();
            if (keys.size() != 1) {
                return null;
            }
            Iterator<Object> keyIt = keys.get(0).values().iterator();
            if (keyIt.hasNext()) {
                Object keyObj = keyIt.next();
                if (keyObj instanceof Number) {
                    return ((Number) keyObj).longValue();
                }
            }
            return null;
        }

        @Override
        public String toString() {
            return "UpdateResult [key=" + key + ", numberOfRowAffected=" + numberOfRowAffected + "]";
        }
    }

    /**
     * Wraps another creator and configures the statement for forward-only reading with a fetch size of
     * {@link Integer#MIN_VALUE}.
     */
    private class StreamingStatementCreator implements PreparedStatementCreator {
        final PreparedStatementCreator delegate;

        public StreamingStatementCreator(PreparedStatementCreator delegate) {
            this.delegate = delegate;
        }

        @Override
        public PreparedStatement createPreparedStatement(Connection connection) throws SQLException {
            final PreparedStatement statement = delegate.createPreparedStatement(connection);
            statement.setFetchDirection(ResultSet.FETCH_FORWARD);
            try {
                // this is for mysql streaming. It makes the driver send the rows as soon as it gets it from the
                // database.
                statement.setFetchSize(Integer.MIN_VALUE);
            } catch (SQLException e) {
                // Other drivers may reject a negative fetch size; the query still runs without streaming.
                logger.warn("Unable to set fetch size to MIN_VALUE for enabling streaming in some DB engine");
            }
            return statement;
        }
    }

    public void setSqlLogger(SqlLogger sqlLogger) {
        this.sqlLogger = sqlLogger;
    }

    /**
     * Best-effort history recording: when this DAO implements {@link DaoWithHistory} and {@code procName} is
     * neither empty nor the history call itself, inserts the row returned by
     * {@link DaoWithHistory#getRowForHistory} into the history table. Any failure is logged and swallowed so
     * that it never blocks the actual update/delete.
     *
     * @param operation value stored in {@code histOperation} ("UPDATE" or "DELETE")
     * @param procName name of the call being historized
     * @param args arguments of that call; {@code userId} among them is stored as {@code histUserId}
     */
    private void checkForHistory(String operation, String procName, Object... args) {
        if (!(this instanceof DaoWithHistory)) {
            return;
        }
        if (Strings.isNullOrEmpty(procName) || procName.equals(PROC_NAME_HISTORIZE)) {
            return;
        }
        DaoWithHistory dwh = (DaoWithHistory) this;
        try {
            MapSqlParameterSource params = prepParams(args);
            Object original = dwh.getRowForHistory(params);
            if (insertHistorySql == null) {
                // Stays null when the INSERT cannot be parsed, so derivation is retried on the next call.
                insertHistorySql = buildInsertHistorySql(dwh.getInsertSql());
                if (insertHistorySql == null) {
                    return;
                }
            }
            params.addValue("histId", 0);
            params.addValue("histOperation", operation);
            params.addValue("histStamp", "");
            // MapSqlParameterSource.getValue throws when "userId" is absent; that is caught and logged below.
            params.addValue("histUserId", params.getValue("userId"));
            insert(PROC_NAME_HISTORIZE, insertHistorySql, params, original);
        } catch (Throwable e) {
            logger.warn("Failed to insert history!", e);
        }
    }

    /**
     * Derives the history INSERT from the DAO's INSERT.
     *
     * <p>The table name gets a {@code Hist} suffix. {@code histId, histOperation, histStamp, histUserId} is
     * prepended to the column list, and {@code :histId, :histOperation, CURRENT_TIMESTAMP, :histUserId} is
     * prepended to the VALUES list.
     *
     * @param insertSql the DAO's INSERT statement
     * @return the history INSERT, or {@code null} (after logging a warning) when either pattern is not found
     */
    private String buildInsertHistorySql(String insertSql) {
        Matcher m = INSERT_COLUMNS_START.matcher(insertSql);
        if (!m.find()) {
            logger.warn("Could not find pattern \"INSERT INTO ... (\" in insert SQL: {}. Ignoring history.", insertSql);
            return null;
        }
        int columnsStart = m.end();
        // Put the "INSERT INTO table" prefix on one line. Line breaks become spaces (not "") so that e.g.
        // "INTO\ntable" does not fuse into a single token.
        String tableClause = insertSql.substring(0, columnsStart - 1).replaceAll("[\\r\\n]+", " ").trim();
        String sql = tableClause + "Hist (" + "histId, histOperation, histStamp, histUserId,"
                + insertSql.substring(columnsStart);
        m = VALUES_START.matcher(sql);
        if (!m.find()) {
            logger.warn("Could not find pattern \"VALUES (\" in insert SQL: {}. Ignoring history.", sql);
            return null;
        }
        int valuesStart = m.end();
        return sql.substring(0, valuesStart) + ":histId, :histOperation, CURRENT_TIMESTAMP, :histUserId,"
                + sql.substring(valuesStart);
    }
}