package com.common.test;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import javax.sql.DataSource;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.AmbiguousTableNameException;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.QueryDataSet;
import org.dbunit.dataset.CachedDataSet;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.xml.FlatXmlDataSet;
import org.dbunit.operation.DatabaseOperation;
import org.springframework.util.ResourceUtils;
/**
 * Helper for database tests: loads DbUnit flat-XML test data into the database,
 * deletes it again, and can snapshot/restore or export table contents.
 *
 * <p>A DbUnit connection must be available before any database operation is
 * invoked; it is created by {@link #setDataSource(DataSource)} or supplied via
 * {@link #setDatabaseConnection(DatabaseConnection)}.
 *
 * @author badqiu1
 */
public class DBUnitFlatXmlHelper {
    /** DbUnit connection used by every database operation of this helper. */
    private DatabaseConnection databaseConnection;
    /** Snapshot taken by {@link #backupTables(String[])}; {@code null} when no backup is pending. */
    private CachedDataSet cachedDataSet;
    /** Data source from which the DbUnit connection is created. */
    private DataSource dataSource;
    /** Data sets loaded through {@link #insertTestData(File)} and not yet deleted. */
    private final List<IDataSet> testDataSets = new ArrayList<IDataSet>();
    /** Schema passed to the DbUnit connection; may be {@code null}. */
    private String jdbcSchema = null;

    /** Creates a helper without a connection; set a data source or connection before use. */
    public DBUnitFlatXmlHelper() {
    }

    /**
     * Creates a helper connected through the given data source, without a schema.
     *
     * @param d data source to obtain the JDBC connection from
     * @throws Exception if the DbUnit connection cannot be created
     */
    public DBUnitFlatXmlHelper(DataSource d) throws Exception {
        this(d, null);
    }

    /**
     * Creates a helper connected through the given data source and schema.
     *
     * @param dataSource data source to obtain the JDBC connection from
     * @param jdbcSchema schema passed to the DbUnit connection; may be {@code null}
     * @throws Exception if the DbUnit connection cannot be created
     */
    public DBUnitFlatXmlHelper(DataSource dataSource, String jdbcSchema) throws Exception {
        this.jdbcSchema = jdbcSchema;
        setDataSource(dataSource);
    }

    /**
     * Sets the schema used the next time a DbUnit connection is created.
     * An already created connection is not affected.
     *
     * @param jdbcSchema schema name; may be {@code null}
     */
    public void setJdbcSchema(String jdbcSchema) {
        this.jdbcSchema = jdbcSchema;
    }

    /**
     * Inserts the test data of several flat-XML files, in the given order.
     * A file that cannot be found is reported on standard output and skipped.
     *
     * @param flatXmlDataFiles resource locations resolved with
     *        {@code ResourceUtils.getFile}; {@code null} is treated as "nothing to do"
     * @throws DatabaseUnitException if DbUnit fails to load or apply a data set
     * @throws SQLException if a database access error occurs
     * @throws IOException if a data file cannot be read
     */
    public void insertTestDatas(String... flatXmlDataFiles) throws DatabaseUnitException, SQLException, IOException {
        if (flatXmlDataFiles == null) {
            return;
        }
        for (String dataFile : flatXmlDataFiles) {
            try {
                File file = ResourceUtils.getFile(dataFile);
                System.out.println("[DbUnit INFO] insert test dataFile,filepath="+dataFile);
                insertTestData(file);
            } catch (FileNotFoundException e) {
                // Deliberately best effort: a missing data file only produces a warning.
                System.out.println("[DbUnit WARN] not found DbUnit DataFile,filepath="+ dataFile);
            }
        }
    }

    /**
     * Inserts test data from a DbUnit flat-XML file using the REFRESH operation
     * and remembers the data set so {@link #deleteTestDatas()} can remove it later.
     *
     * <p>Example file:
     * <pre>
     * &lt;dataset&gt;
     *   &lt;TABLE_NAME field1="username1" field2="pwd1"/&gt;
     *   &lt;TABLE_NAME field2="password" field1="username2"/&gt;
     * &lt;/dataset&gt;
     * </pre>
     * {@code TABLE_NAME} is the table name; {@code field1} and {@code field2} are column names.
     *
     * @param flatXMLFile the flat-XML data file to load
     * @throws FileNotFoundException if the file does not exist or cannot be opened
     * @throws IOException if the file cannot be read or closed
     * @throws DatabaseUnitException if DbUnit fails; an ambiguous table name is
     *         rethrown with a hint about the Oracle recycle bin and the JDBC schema
     * @throws SQLException if a database access error occurs
     * @see <a href="http://dbunit.sourceforge.net/components.html#dataset">DbUnit data sets</a>
     */
    public void insertTestData(File flatXMLFile) throws FileNotFoundException, IOException, DatabaseUnitException, SQLException {
        FileInputStream in = new FileInputStream(flatXMLFile);
        try {
            IDataSet testDataSet = new FlatXmlDataSet(in);
            DatabaseOperation.REFRESH.execute(getDatabaseConnection(), testDataSet);
            testDataSets.add(testDataSet);
        } catch (AmbiguousTableNameException e) {
            throw new DatabaseUnitException("出现AmbiguousTableNameException异常,使用命令:purge recyclebin清空一下oracle回收站,并为dbunit指定jdbcSchema",e);
        } finally {
            // Release the file handle whether or not loading succeeded.
            in.close();
        }
    }

    /**
     * Deletes all test data previously inserted through this helper.
     * Each data set is forgotten as soon as its DELETE operation succeeded, so
     * after a failure only the data sets not yet deleted remain registered.
     *
     * @throws DatabaseUnitException if DbUnit fails to delete a data set
     * @throws SQLException if a database access error occurs
     */
    public void deleteTestDatas() throws DatabaseUnitException, SQLException {
        System.out.println("[DbUnit INFO] delete test datas");
        for (ListIterator<IDataSet> it = testDataSets.listIterator(); it.hasNext();) {
            IDataSet testDataSet = it.next();
            DatabaseOperation.DELETE.execute(getDatabaseConnection(), testDataSet);
            it.remove();
        }
    }

    /**
     * Backs up a single table.
     *
     * @param tableName name of the table to back up
     * @throws Exception if the snapshot cannot be created
     */
    public void backupTable(String tableName) throws Exception {
        String[] tableNames = { tableName };
        backupTables(tableNames);
    }

    /**
     * Backs up the given tables into an in-memory snapshot, replacing any
     * snapshot taken earlier.
     *
     * @param tableNames names of the database tables to back up
     * @throws Exception if the snapshot cannot be created
     */
    public void backupTables(String[] tableNames) throws Exception {
        cachedDataSet = new CachedDataSet(getDatabaseConnection().createDataSet(
                tableNames));
    }

    /**
     * Deletes the inserted test data and, if a backup exists, restores the
     * backed-up tables with a CLEAN_INSERT (delete all rows, then insert the
     * snapshot). The backup is discarded afterwards.
     *
     * @throws Exception if deleting or restoring fails
     */
    public void restoreTables() throws Exception {
        deleteTestDatas();
        if (cachedDataSet != null) {
            DatabaseOperation.CLEAN_INSERT.execute(getDatabaseConnection(), cachedDataSet);
            cachedDataSet = null;
        }
    }

    /**
     * Exports the results of several queries into one flat-XML file.
     * {@code tableName[i]} is the table name under which the rows returned by
     * {@code sqlQuery[i]} are written.
     *
     * @param tableName table names, one per query
     * @param sqlQuery SQL queries, one per table name
     * @param flatXMLFile the file to write
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws DataSetException if DbUnit fails to build or write the data set
     * @throws FileNotFoundException if the file cannot be opened for writing
     * @throws SQLException if a database access error occurs
     * @throws IOException if the file cannot be written or closed
     */
    public void exportFromDBToFile(String tableName[],String sqlQuery[],File flatXMLFile) throws DataSetException, FileNotFoundException, SQLException, IOException {
        if (tableName.length != sqlQuery.length) {
            throw new IllegalArgumentException("tableName.length="+tableName.length+" != query.length="+sqlQuery.length);
        }
        // Collect every table/query pair in one data set and write the file once;
        // writing per pair would overwrite the file on each iteration.
        QueryDataSet partialDataSet = new QueryDataSet(getDatabaseConnection());
        for (int i = 0; i < tableName.length; i++) {
            partialDataSet.addTable(tableName[i], sqlQuery[i]);
        }
        writeFlatXml(partialDataSet, flatXMLFile);
    }

    /**
     * Exports the result of one query into a flat-XML file.
     *
     * @param tableName table name under which the rows are written
     * @param sqlQuery SQL query selecting the rows to export
     * @param flatXMLFile the file to write
     * @throws DataSetException if DbUnit fails to build or write the data set
     * @throws FileNotFoundException if the file cannot be opened for writing
     * @throws SQLException if a database access error occurs
     * @throws IOException if the file cannot be written or closed
     */
    public void exportFromDBToFile(String tableName,String sqlQuery,File flatXMLFile) throws DataSetException, FileNotFoundException, SQLException, IOException {
        QueryDataSet partialDataSet = new QueryDataSet(getDatabaseConnection());
        partialDataSet.addTable(tableName, sqlQuery);
        writeFlatXml(partialDataSet, flatXMLFile);
    }

    /**
     * Exports all rows of a single table into a DbUnit flat-XML file.
     *
     * @see #insertTestData(File)
     * @param tableName name of the table to export
     * @param flatXMLFile the file to write
     * @throws DataSetException if DbUnit fails to build or write the data set
     * @throws FileNotFoundException if the file cannot be opened for writing
     * @throws SQLException if a database access error occurs
     * @throws IOException if the file cannot be written or closed
     */
    public void exportFromDBToFile(String tableName,File flatXMLFile) throws DataSetException, FileNotFoundException, SQLException, IOException {
        exportFromDBToFile(new String[]{tableName}, flatXMLFile);
    }

    /**
     * Exports all rows of several tables into one DbUnit flat-XML file.
     *
     * @see #insertTestData(File)
     * @param tableNames names of the tables to export
     * @param flatXMLFile the file to write
     * @throws SQLException if a database access error occurs
     * @throws DataSetException if DbUnit fails to build or write the data set
     * @throws FileNotFoundException if the file cannot be opened for writing
     * @throws IOException if the file cannot be written or closed
     */
    public void exportFromDBToFile(String[] tableNames,File flatXMLFile) throws SQLException, DataSetException, FileNotFoundException, IOException {
        QueryDataSet partialDataSet = new QueryDataSet(getDatabaseConnection());
        for (int i = 0; i < tableNames.length; i++) {
            partialDataSet.addTable(tableNames[i]);
        }
        writeFlatXml(partialDataSet, flatXMLFile);
    }

    /**
     * Sets the data source and creates a new DbUnit connection from it.
     * A connection held before this call is replaced without being closed;
     * call {@link #close()} first if it should be released.
     *
     * @param d the data source; if {@code null}, the current connection is left unchanged
     * @throws Exception if the DbUnit connection cannot be created
     */
    public void setDataSource(DataSource d) throws Exception {
        this.dataSource = d;
        initDbUnitConnection();
    }

    /**
     * Sets the schema and the data source, then creates a new DbUnit connection.
     *
     * @param d the data source
     * @param jdbcSchema schema passed to the DbUnit connection; may be {@code null}
     * @throws Exception if the DbUnit connection cannot be created
     */
    public void setDataSource(DataSource d,String jdbcSchema) throws Exception {
        setJdbcSchema(jdbcSchema);
        setDataSource(d);
    }

    /**
     * Uses an externally created DbUnit connection instead of one derived from a data source.
     *
     * @param databaseConnection the connection to use
     */
    public void setDatabaseConnection(DatabaseConnection databaseConnection) {
        this.databaseConnection = databaseConnection;
    }

    /**
     * Closes the DbUnit connection, if one is set.
     *
     * @throws SQLException if closing the connection fails
     */
    public void close() throws SQLException {
        if (databaseConnection != null) {
            databaseConnection.close();
        }
    }

    /**
     * Creates the DbUnit connection from the configured data source and schema.
     * Does nothing when no data source is set.
     *
     * @throws Exception if the database connection cannot be obtained
     */
    private void initDbUnitConnection() throws Exception {
        if (dataSource != null) {
            databaseConnection = new DatabaseConnection(dataSource.getConnection(), jdbcSchema);
        }
    }

    /**
     * Returns the DbUnit connection.
     *
     * @return the current connection, never {@code null}
     * @throws IllegalStateException if no connection has been set up yet
     */
    private DatabaseConnection getDatabaseConnection() {
        if (databaseConnection == null) {
            throw new IllegalStateException("请先设置数据源,执行setDataSource()");
        }
        return databaseConnection;
    }

    /** Writes the data set to the file in flat-XML format and always closes the stream. */
    private void writeFlatXml(IDataSet dataSet, File flatXMLFile) throws DataSetException, IOException {
        FileOutputStream out = new FileOutputStream(flatXMLFile);
        try {
            FlatXmlDataSet.write(dataSet, out);
        } finally {
            out.close();
        }
    }
}