package liquibase.statementexecute;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import liquibase.CatalogAndSchema;
import liquibase.changelog.ChangeLogHistoryServiceFactory;
import liquibase.database.Database;
import liquibase.database.DatabaseConnection;
import liquibase.database.jvm.JdbcConnection;
import liquibase.exception.UnexpectedLiquibaseException;
import liquibase.lockservice.LockServiceFactory;
import liquibase.snapshot.SnapshotGeneratorFactory;
import liquibase.structure.core.Table;
import liquibase.datatype.DataTypeFactory;
import liquibase.database.example.ExampleCustomDatabase;
import liquibase.sdk.database.MockDatabase;
import liquibase.database.core.UnsupportedDatabase;
import liquibase.executor.ExecutorService;
import liquibase.sql.Sql;
import liquibase.sqlgenerator.SqlGeneratorFactory;
import liquibase.statement.SqlStatement;
import liquibase.test.TestContext;
import liquibase.test.DatabaseTestContext;
import liquibase.exception.DatabaseException;
import org.junit.After;
public abstract class AbstractExecuteTest {
private Set<Class<? extends Database>> testedDatabases = new HashSet<Class<? extends Database>>();
protected SqlStatement statementUnderTest;
@After
public void reset() {
for (Database database : TestContext.getInstance().getAllDatabases()) {
if (database.getConnection() != null) {
try {
database.rollback();
} catch (DatabaseException e) {
//ok
}
}
}
testedDatabases = new HashSet<Class<? extends Database>>();
this.statementUnderTest = null;
SnapshotGeneratorFactory.resetAll();
}
protected abstract List<? extends SqlStatement> setupStatements(Database database);
protected void testOnAll(String expectedSql) throws Exception {
test(expectedSql, null, null);
}
protected void assertCorrectOnRest(String expectedSql) throws Exception {
assertCorrect(expectedSql);
}
protected void assertCorrect(String expectedSql, Class<? extends Database>... includeDatabases) throws Exception {
assertCorrect(new String[]{expectedSql}, includeDatabases);
}
protected void assertCorrect(String[] expectedSql, Class<? extends Database>... includeDatabases) throws Exception {
assertNotNull(statementUnderTest);
test(expectedSql, includeDatabases, null);
}
public void testOnAllExcept(String expectedSql, Class<? extends Database>... excludedDatabases) throws Exception {
test(expectedSql, null, excludedDatabases);
}
private void test(String expectedSql, Class<? extends Database>[] includeDatabases, Class<? extends Database>[] excludeDatabases) throws Exception {
test(new String[]{expectedSql}, includeDatabases, excludeDatabases);
}
private void test(String[] expectedSql, Class<? extends Database>[] includeDatabases, Class<? extends Database>[] excludeDatabases) throws Exception {
if (expectedSql != null) {
for (Database database : TestContext.getInstance().getAllDatabases()) {
if (shouldTestDatabase(database, includeDatabases, excludeDatabases)) {
testedDatabases.add(database.getClass());
if (database.getConnection() != null) {
ChangeLogHistoryServiceFactory.getInstance().getChangeLogService(database).init();
LockServiceFactory.getInstance().getLockService(database).init();
}
Sql[] sql = SqlGeneratorFactory.getInstance().generateSql(statementUnderTest, database);
assertNotNull("Null SQL for " + database, sql);
assertEquals("Unexpected number of SQL statements for " + database, expectedSql.length, sql.length);
int index = 0;
for (String convertedSql : expectedSql) {
convertedSql = replaceEscaping(convertedSql, database);
convertedSql = replaceDatabaseClauses(convertedSql, database);
convertedSql = replaceStandardTypes(convertedSql, database);
assertEquals("Incorrect SQL for " + database.getClass().getName(), convertedSql.toLowerCase().trim(), sql[index].toSql().toLowerCase());
index++;
}
}
}
}
resetAvailableDatabases();
for (Database availableDatabase : DatabaseTestContext.getInstance().getAvailableDatabases()) {
Statement statement = ((JdbcConnection) availableDatabase.getConnection()).getUnderlyingConnection().createStatement();
if (shouldTestDatabase(availableDatabase, includeDatabases, excludeDatabases)) {
String sqlToRun = SqlGeneratorFactory.getInstance().generateSql(statementUnderTest, availableDatabase)[0].toSql();
try {
statement.execute(sqlToRun);
} catch (Exception e) {
System.out.println("Failed to execute against " + availableDatabase.getShortName() + ": " + sqlToRun);
throw e;
}
}
}
}
private String replaceStandardTypes(String convertedSql, Database database) {
convertedSql = replaceType("int", convertedSql, database);
convertedSql = replaceType("datetime", convertedSql, database);
convertedSql = replaceType("boolean", convertedSql, database);
convertedSql = convertedSql.replaceAll("FALSE", DataTypeFactory.getInstance().fromDescription("boolean", database).objectToSql(false, database));
convertedSql = convertedSql.replaceAll("TRUE", DataTypeFactory.getInstance().fromDescription("boolean", database).objectToSql(true, database));
convertedSql = convertedSql.replaceAll("NOW\\(\\)", database.getCurrentDateTimeFunction());
return convertedSql;
}
private String replaceType(String type, String baseString, Database database) {
return baseString.replaceAll(" " + type + " ", " " + DataTypeFactory.getInstance().fromDescription(type, database).toDatabaseDataType(database).toString() + " ")
.replaceAll(" " + type + ",", " " + DataTypeFactory.getInstance().fromDescription(type, database).toDatabaseDataType(database).toString() + ",");
}
private String replaceDatabaseClauses(String convertedSql, Database database) {
return convertedSql.replaceFirst("auto_increment_clause", database.getAutoIncrementClause(null, null));
}
private boolean shouldTestDatabase(Database database, Class<? extends Database>[] includeDatabases, Class<? extends Database>[] excludeDatabases) {
if (database instanceof MockDatabase || database instanceof ExampleCustomDatabase || database instanceof UnsupportedDatabase) {
return false;
}
if (!SqlGeneratorFactory.getInstance().supports(statementUnderTest, database)
|| SqlGeneratorFactory.getInstance().validate(statementUnderTest, database).hasErrors()) {
return false;
}
boolean shouldInclude = true;
if (includeDatabases != null && includeDatabases.length > 0) {
shouldInclude = Arrays.asList(includeDatabases).contains(database.getClass());
}
boolean shouldExclude = false;
if (excludeDatabases != null && excludeDatabases.length > 0) {
shouldExclude = Arrays.asList(excludeDatabases).contains(database.getClass());
}
return !shouldExclude && shouldInclude && !testedDatabases.contains(database.getClass());
}
private String replaceEscaping(String expectedSql, Database database) {
String convertedSql = expectedSql;
int lastIndex = 0;
while ((lastIndex = convertedSql.indexOf("[", lastIndex)) >= 0) {
String objectName = convertedSql.substring(lastIndex + 1, convertedSql.indexOf("]", lastIndex));
try {
convertedSql = convertedSql.replace("[" + objectName + "]", database.escapeObjectName(objectName, Table.class));
} catch (Exception e) {
throw new RuntimeException(e);
}
lastIndex++;
}
return convertedSql;
}
public void resetAvailableDatabases() throws Exception {
for (Database database : DatabaseTestContext.getInstance().getAvailableDatabases()) {
DatabaseConnection connection = database.getConnection();
Statement connectionStatement = ((JdbcConnection) connection).getUnderlyingConnection().createStatement();
try {
database.dropDatabaseObjects(CatalogAndSchema.DEFAULT);
} catch (Throwable e) {
throw new UnexpectedLiquibaseException("Error dropping objects for database "+database.getShortName(), e);
}
try {
connectionStatement.executeUpdate("drop table " + database.escapeTableName(database.getLiquibaseCatalogName(), database.getLiquibaseSchemaName(), database.getDatabaseChangeLogLockTableName()));
} catch (SQLException e) {
;
}
connection.commit();
try {
connectionStatement.executeUpdate("drop table " + database.escapeTableName(database.getLiquibaseCatalogName(), database.getLiquibaseSchemaName(), database.getDatabaseChangeLogTableName()));
} catch (SQLException e) {
;
}
connection.commit();
if (database.supportsSchemas()) {
database.dropDatabaseObjects(new CatalogAndSchema(DatabaseTestContext.ALT_CATALOG, DatabaseTestContext.ALT_SCHEMA));
connection.commit();
try {
connectionStatement.executeUpdate("drop table " + database.escapeTableName(DatabaseTestContext.ALT_CATALOG, DatabaseTestContext.ALT_SCHEMA, database.getDatabaseChangeLogLockTableName()));
} catch (SQLException e) {
//ok
}
connection.commit();
try {
connectionStatement.executeUpdate("drop table " + database.escapeTableName(DatabaseTestContext.ALT_CATALOG, DatabaseTestContext.ALT_SCHEMA, database.getDatabaseChangeLogTableName()));
} catch (SQLException e) {
//ok
}
connection.commit();
}
List<? extends SqlStatement> setupStatements = setupStatements(database);
if (setupStatements != null) {
for (SqlStatement statement : setupStatements) {
ExecutorService.getInstance().getExecutor(database).execute(statement);
}
}
connectionStatement.close();
}
}
}