Package liquibase.statementexecute

Source Code of liquibase.statementexecute.AbstractExecuteTest

package liquibase.statementexecute;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import liquibase.CatalogAndSchema;
import liquibase.changelog.ChangeLogHistoryServiceFactory;
import liquibase.database.Database;
import liquibase.database.DatabaseConnection;
import liquibase.database.jvm.JdbcConnection;
import liquibase.exception.UnexpectedLiquibaseException;
import liquibase.lockservice.LockServiceFactory;
import liquibase.snapshot.SnapshotGeneratorFactory;
import liquibase.structure.core.Table;
import liquibase.datatype.DataTypeFactory;
import liquibase.database.example.ExampleCustomDatabase;
import liquibase.sdk.database.MockDatabase;
import liquibase.database.core.UnsupportedDatabase;
import liquibase.executor.ExecutorService;
import liquibase.sql.Sql;
import liquibase.sqlgenerator.SqlGeneratorFactory;
import liquibase.statement.SqlStatement;
import liquibase.test.TestContext;
import liquibase.test.DatabaseTestContext;
import liquibase.exception.DatabaseException;

import org.junit.After;

public abstract class AbstractExecuteTest {

    private Set<Class<? extends Database>> testedDatabases = new HashSet<Class<? extends Database>>();
    protected SqlStatement statementUnderTest;

    @After
    public void reset() {
        for (Database database : TestContext.getInstance().getAllDatabases()) {
            if (database.getConnection() != null) {
                try {
                    database.rollback();
                } catch (DatabaseException e) {
                    //ok
                }
            }
        }
        testedDatabases = new HashSet<Class<? extends Database>>();
        this.statementUnderTest = null;

        SnapshotGeneratorFactory.resetAll();
    }

    protected abstract List<? extends SqlStatement> setupStatements(Database database);

    protected void testOnAll(String expectedSql) throws Exception {
        test(expectedSql, null, null);
    }

    protected void assertCorrectOnRest(String expectedSql) throws Exception {
        assertCorrect(expectedSql);
    }

    protected void assertCorrect(String expectedSql, Class<? extends Database>... includeDatabases) throws Exception {
        assertCorrect(new String[]{expectedSql}, includeDatabases);
    }

    protected void assertCorrect(String[] expectedSql, Class<? extends Database>... includeDatabases) throws Exception {
        assertNotNull(statementUnderTest);

        test(expectedSql, includeDatabases, null);
    }

    public void testOnAllExcept(String expectedSql, Class<? extends Database>... excludedDatabases) throws Exception {
        test(expectedSql, null, excludedDatabases);
    }

    private void test(String expectedSql, Class<? extends Database>[] includeDatabases, Class<? extends Database>[] excludeDatabases) throws Exception {
        test(new String[]{expectedSql}, includeDatabases, excludeDatabases);
    }

    private void test(String[] expectedSql, Class<? extends Database>[] includeDatabases, Class<? extends Database>[] excludeDatabases) throws Exception {

        if (expectedSql != null) {
            for (Database database : TestContext.getInstance().getAllDatabases()) {
                if (shouldTestDatabase(database, includeDatabases, excludeDatabases)) {
                    testedDatabases.add(database.getClass());

                    if (database.getConnection() != null) {
                        ChangeLogHistoryServiceFactory.getInstance().getChangeLogService(database).init();
                        LockServiceFactory.getInstance().getLockService(database).init();
                    }

                    Sql[] sql = SqlGeneratorFactory.getInstance().generateSql(statementUnderTest, database);

                    assertNotNull("Null SQL for " + database, sql);
                    assertEquals("Unexpected number of  SQL statements for " + database, expectedSql.length, sql.length);

                    int index = 0;
                    for (String convertedSql : expectedSql) {
                        convertedSql = replaceEscaping(convertedSql, database);
                        convertedSql = replaceDatabaseClauses(convertedSql, database);
                        convertedSql = replaceStandardTypes(convertedSql, database);

                        assertEquals("Incorrect SQL for " + database.getClass().getName(), convertedSql.toLowerCase().trim(), sql[index].toSql().toLowerCase());
                        index++;
                    }
                }
            }
        }

        resetAvailableDatabases();
        for (Database availableDatabase : DatabaseTestContext.getInstance().getAvailableDatabases()) {
            Statement statement = ((JdbcConnection) availableDatabase.getConnection()).getUnderlyingConnection().createStatement();
            if (shouldTestDatabase(availableDatabase, includeDatabases, excludeDatabases)) {
                String sqlToRun = SqlGeneratorFactory.getInstance().generateSql(statementUnderTest, availableDatabase)[0].toSql();
                try {
                    statement.execute(sqlToRun);
                } catch (Exception e) {
                    System.out.println("Failed to execute against " + availableDatabase.getShortName() + ": " + sqlToRun);
                    throw e;

                }
            }
        }
    }

    private String replaceStandardTypes(String convertedSql, Database database) {
        convertedSql = replaceType("int", convertedSql, database);
        convertedSql = replaceType("datetime", convertedSql, database);
        convertedSql = replaceType("boolean", convertedSql, database);

        convertedSql = convertedSql.replaceAll("FALSE", DataTypeFactory.getInstance().fromDescription("boolean", database).objectToSql(false, database));
        convertedSql = convertedSql.replaceAll("TRUE", DataTypeFactory.getInstance().fromDescription("boolean", database).objectToSql(true, database));
        convertedSql = convertedSql.replaceAll("NOW\\(\\)", database.getCurrentDateTimeFunction());

        return convertedSql;
    }

    private String replaceType(String type, String baseString, Database database) {
        return baseString.replaceAll(" " + type + " ", " " + DataTypeFactory.getInstance().fromDescription(type, database).toDatabaseDataType(database).toString() + " ")
                .replaceAll(" " + type + ",", " " + DataTypeFactory.getInstance().fromDescription(type, database).toDatabaseDataType(database).toString() + ",");
    }

    private String replaceDatabaseClauses(String convertedSql, Database database) {
        return convertedSql.replaceFirst("auto_increment_clause", database.getAutoIncrementClause(null, null));
    }

    private boolean shouldTestDatabase(Database database, Class<? extends Database>[] includeDatabases, Class<? extends Database>[] excludeDatabases) {
        if (database instanceof MockDatabase || database instanceof ExampleCustomDatabase || database instanceof UnsupportedDatabase) {
            return false;
        }
        if (!SqlGeneratorFactory.getInstance().supports(statementUnderTest, database)
                || SqlGeneratorFactory.getInstance().validate(statementUnderTest, database).hasErrors()) {
            return false;
        }

        boolean shouldInclude = true;
        if (includeDatabases != null && includeDatabases.length > 0) {
            shouldInclude = Arrays.asList(includeDatabases).contains(database.getClass());
        }

        boolean shouldExclude = false;
        if (excludeDatabases != null && excludeDatabases.length > 0) {
            shouldExclude = Arrays.asList(excludeDatabases).contains(database.getClass());
        }

        return !shouldExclude && shouldInclude && !testedDatabases.contains(database.getClass());


    }

    private String replaceEscaping(String expectedSql, Database database) {
        String convertedSql = expectedSql;
        int lastIndex = 0;
        while ((lastIndex = convertedSql.indexOf("[", lastIndex)) >= 0) {
            String objectName = convertedSql.substring(lastIndex + 1, convertedSql.indexOf("]", lastIndex));
            try {
                convertedSql = convertedSql.replace("[" + objectName + "]", database.escapeObjectName(objectName, Table.class));
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            lastIndex++;
        }

        return convertedSql;
    }

    public void resetAvailableDatabases() throws Exception {
        for (Database database : DatabaseTestContext.getInstance().getAvailableDatabases()) {
            DatabaseConnection connection = database.getConnection();
            Statement connectionStatement = ((JdbcConnection) connection).getUnderlyingConnection().createStatement();

            try {
                database.dropDatabaseObjects(CatalogAndSchema.DEFAULT);
            } catch (Throwable e) {
                throw new UnexpectedLiquibaseException("Error dropping objects for database "+database.getShortName(), e);
            }
            try {
                connectionStatement.executeUpdate("drop table " + database.escapeTableName(database.getLiquibaseCatalogName(), database.getLiquibaseSchemaName(), database.getDatabaseChangeLogLockTableName()));
            } catch (SQLException e) {
                ;
            }
            connection.commit();
            try {
                connectionStatement.executeUpdate("drop table " + database.escapeTableName(database.getLiquibaseCatalogName(), database.getLiquibaseSchemaName(), database.getDatabaseChangeLogTableName()));
            } catch (SQLException e) {
                ;
            }
            connection.commit();

            if (database.supportsSchemas()) {
                database.dropDatabaseObjects(new CatalogAndSchema(DatabaseTestContext.ALT_CATALOG, DatabaseTestContext.ALT_SCHEMA));
                connection.commit();

                try {
                    connectionStatement.executeUpdate("drop table " + database.escapeTableName(DatabaseTestContext.ALT_CATALOG, DatabaseTestContext.ALT_SCHEMA, database.getDatabaseChangeLogLockTableName()));
                } catch (SQLException e) {
                    //ok
                }
                connection.commit();
                try {
                    connectionStatement.executeUpdate("drop table " + database.escapeTableName(DatabaseTestContext.ALT_CATALOG, DatabaseTestContext.ALT_SCHEMA, database.getDatabaseChangeLogTableName()));
                } catch (SQLException e) {
                    //ok
                }
                connection.commit();
            }

            List<? extends SqlStatement> setupStatements = setupStatements(database);
            if (setupStatements != null) {
                for (SqlStatement statement : setupStatements) {
                    ExecutorService.getInstance().getExecutor(database).execute(statement);
                }
            }
            connectionStatement.close();
        }
    }

}
TOP

Related Classes of liquibase.statementexecute.AbstractExecuteTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.