Package com.alexkasko.springjdbc.parallel

Source Code of com.alexkasko.springjdbc.parallel.ParallelQueriesIterator$SubmitFun

package com.alexkasko.springjdbc.parallel;

import com.alexkasko.springjdbc.parallel.accessor.DataSourceAccessor;
import com.alexkasko.springjdbc.parallel.accessor.RoundRobinAccessor;
import com.google.common.base.Function;
import com.google.common.collect.*;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.DataAccessException;
import org.springframework.jdbc.core.*;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;

import javax.sql.DataSource;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static com.alexkasko.springjdbc.parallel.NamedParameterJdbcTemplateFunction.NPJT_FUNCTION;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static org.springframework.util.StringUtils.hasText;

/**
* Executes single SQL query to multiple data sources in parallel using provided executor.
* Provides data to client as iterator, {@link ArrayBlockingQueue} is used for temporary buffering.
* Iteration will block awaiting data loaded from sources.
* Typical usage is to get new instance somewhere (spring prototype bean etc.), provide query params
* with <code>start</code> method and iterate over until end.
* Data source exceptions will be propagated as runtime exceptions thrown on 'next()' or 'hasNext()' call.
* All parallel queries will be cancelled on one query error.
* Iterator takes sql in the same format as {@code NamedParameterJdbcTemplate} does (with {@code :palceholders})
* and map declared parameters to provided values using the same methods as {@code NamedParameterJdbcTemplate}.
* <b>NOT</b> thread-safe (tbd: specify points that break thread safety), instance may be reused calling <code>start</code> method, but only in one thread simultaneously.
*
* @author  alexkasko
* Date: 6/8/12
* @see ParallelQueriesListener
* @see com.alexkasko.springjdbc.parallel.accessor.DataSourceAccessor
*/

public class ParallelQueriesIterator<T> extends AbstractIterator<T> {
    private static final Log logger = LogFactory.getLog(ParallelQueriesIterator.class);

    private final Object endOfDataObject = new Object();
    private final ExceptionHolder exceptionHolder = new ExceptionHolder();

    private final DataSourceAccessor<?, ?> sources;
    private final String sql;
    private final RowMapperFactory<T, ?> mapperFactory;
    private final ExecutorService executor;
    // was made non-generic to allow endOfDataObject
    private final ArrayBlockingQueue<Object> dataQueue;
    private final List<ParallelQueriesListener> listeners = Lists.newArrayList();
    private final int maxDataPollWaitSeconds;

    private AtomicBoolean started = new AtomicBoolean(false);
    private AtomicInteger sourcesRemained = new AtomicInteger(0);
    // this field needed anyway in start method, so we'll use it in workers instead of thread.isInterrupted
    private final AtomicBoolean cancelled = new AtomicBoolean(false);
    private final ConcurrentHashMap<String, Statement> activeStatements = new ConcurrentHashMap<String, Statement>();
    private ImmutableList<Future<?>> futures;

    /**
     * Shortcut constructor, uses {@link RoundRobinAccessor}, new cached thread pool,
     * and {@code JdbcTemplate}s with default settings
     *
     *
     * @param sources list of data sources, will be used in round-robin mode
     * @param sql query to execute using JdbcTemplate
     * @param mapper will be used to get data from result sets
     */
    public ParallelQueriesIterator(Collection<DataSource> sources, String sql, RowMapper<T> mapper) {
        this(RoundRobinAccessor.of(Collections2.transform(sources, NPJT_FUNCTION)),
                sql, Executors.newCachedThreadPool(), mapper, 1024, 60);
    }

    /**
     * Main constructor, uses {@code RowMapper} instead of mapper factory
     *
     * @param sources data sources (wrapped into JDBC templates) accessor
     * @param sql query to execute using JdbcOperations
     * @param executor executor service to run parallel queries into
     * @param mapper will be used to get data from result sets
     * @param bufferSize size of ArrayBlockingQueue data buffer
     */
    public ParallelQueriesIterator(DataSourceAccessor<?, ?> sources,
                                   String sql, ExecutorService executor, RowMapper<T> mapper, int bufferSize, int maxDataPollWaitSeconds) {
        this(sources, sql, executor, SingletoneRowMapperFactory.of(mapper), bufferSize, maxDataPollWaitSeconds);
    }

    /**
     * Most detailed constructor
     *
     * @param sources data sources (wrapped into JDBC templates) accessor
     * @param sql query to execute using JdbcOperations
     * @param mapperFactory will be used to get data from result sets
     * @param executor executor service to run parallel queries into
     * @param bufferSize size of ArrayBlockingQueue data buffer
     */
    public ParallelQueriesIterator(DataSourceAccessor<?, ?> sources,
                                   String sql, ExecutorService executor, RowMapperFactory<T, ?> mapperFactory, int bufferSize, int maxDataPollWaitSeconds) {
        checkNotNull(sources, "Provided data source accessor is null");
        checkArgument(sources.size() > 0, "No data sources provided");
        checkArgument(hasText(sql), "Provided sql query is blank");
        checkNotNull(executor, "Provided executor is null");
        checkNotNull(mapperFactory, "Provided row mapper factory is null");
        checkArgument(bufferSize > 0, "Buffer size mat be positive, but was: '%s'", bufferSize);
        this.sources = sources;
        this.sql = sql;
        this.mapperFactory = mapperFactory;
        this.executor = executor;
        this.dataQueue = new ArrayBlockingQueue<Object>(bufferSize);
        this.maxDataPollWaitSeconds = maxDataPollWaitSeconds;
    }

    /**
     * Starts parallel query execution in data sources. May be called multiple times to reuse iterator instance,
     * Must be called only on exhausted iterators (which have all worker threads finished).
     * May not be called on cancelled iterators.
     *
     * @param params query params
     * @return iterator itself
     */
    public ParallelQueriesIterator<T> start(Collection<? extends SqlParameterSource> params) {
        checkNotNull(params, "Provided parameters collection is null");
        checkArgument(params.size() > 0, "Provided collection is empty");
        checkState(!cancelled.get(), "This iterator is cancelled and cannot be restarted");
        this.dataQueue.clear();
        this.sourcesRemained.set(params.size());
        this.futures = ImmutableList.copyOf(Collections2.transform(params, new SubmitFun()));
        this.started.set(true);
        return this;
    }

    /**
     * Cancels queries processing in all sources, by setting cancelled state, then cancelling all active
     * prepared statements (for workers blocked on query completion) and then interrupting
     * all threads (for workers blocked on result queue).
     * May be called from another thread. Subsequent calls do nothing.
     *
     * @return count of threads that were actually interrupted in processing
     */
    public int cancel() {
        if(!started.get()) return 0;
        // set cancelled state
        boolean wasCancelled = cancelled.getAndSet(true);
        if(wasCancelled) return 0;
        exceptionHolder.set(new ParallelQueriesException("cancelled"));
        // cancel statements
        for(Statement st : activeStatements.values()) {
            cancelStatement(st);
        }
        int res = 0;
        // cancel threads
        for(Future<?> fu : futures) {
            if(fu.cancel(true)) res += 1;
        }
        return res;
    }

    /**
     * Register listener for "query success" and "query error" events
     *
     * @param listener data source query events will be reported to this listener
     * @return iterator itself
     */
    public ParallelQueriesIterator<T> addListener(ParallelQueriesListener listener) {
        this.listeners.add(listener);
        return this;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append("ParallelQueriesIterator");
        sb.append("{endOfDataObject=").append(endOfDataObject);
        sb.append(", sources=").append(sources);
        sb.append(", sql='").append(sql).append('\'');
        sb.append(", mapperFactory=").append(mapperFactory);
        sb.append(", executor=").append(executor);
        sb.append(", started=").append(started);
        sb.append(", sourcesRemained=").append(sourcesRemained);
        sb.append(", cancelled=").append(cancelled);
        sb.append('}');
        return sb.toString();
    }

    /**
     * {@inheritDoc}
     */
    @Override
    @SuppressWarnings("unchecked")
    protected T computeNext() {
        checkState(started.get(), "Iterator wasn't started, call 'start' method first");
        Object ob;
        while(endOfDataObject == (ob = takeData())) {
            if(0 == sourcesRemained.decrementAndGet()) return endOfData();
        }
        return (T) ob;
    }

    private void putData(Object data) {
        try {
            dataQueue.put(data);
        } catch(InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    private Object takeData() {
        try {
            for(int i=0; i< maxDataPollWaitSeconds; i++) {
                checkException();
                Object res = dataQueue.poll(1, TimeUnit.SECONDS);
                if(null != res) return res;
            }
            throw new ParallelQueriesException(
                    "Data wait timeout exceeded: '" + maxDataPollWaitSeconds + " seconds', alive workers count: '" + sourcesRemained.get() + "'");
        } catch(InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new ParallelQueriesException(e);
        }
    }

    private void checkException() {
        RuntimeException workerException = exceptionHolder.get();
        if (null != workerException) {
            cancel();
            throw workerException;
        }
    }

    private void cancelStatement(Statement stmt) {
        try {
            stmt.cancel();
        } catch(Exception e) { // driver may throw different exceptions
            logger.warn("Exception thrown on statement cancelling: '" + stmt + "'", e);
        }
    }

    private class Worker implements Runnable {
        private final NamedParameterJdbcOperations npjo;
        private final SqlParameterSource params;

        private Worker(NamedParameterJdbcOperations npjo, SqlParameterSource params) {
            this.npjo = npjo;
            this.params = params;
        }

        @Override
        @SuppressWarnings("unchecked") // workaround to
        // keep RowMapperFactory fully typed without additional
        // generic arguments to iterator itself
        public void run() {
            String registryKey = Thread.currentThread().getName();
            try {
                RowMapperFactory ungeneric = mapperFactory;
                RowMapper<T> mapper = ungeneric.produce(params);
                Extractor extractor = new Extractor(mapper);
                // sql parameters processing is the same as in NamedParameterJdbcTemplate
                PreparedStatementCreator psc = new CancellableStatementCreator(registryKey, activeStatements, sql, params);
                npjo.getJdbcOperations().query(psc, extractor);
                for(ParallelQueriesListener li : listeners) li.success(npjo, sql, params);
            } catch (Throwable e) { // we do not believe to JDBC drivers' error reporting
                ParallelQueriesException pqe = e instanceof ParallelQueriesException ? (ParallelQueriesException) e : new ParallelQueriesException(e);
                exceptionHolder.set(pqe);
                for(ParallelQueriesListener li : listeners) li.error(npjo, sql, params, e);
            } finally {
                activeStatements.remove(registryKey);
            }
        }
    }

    private class Extractor implements ResultSetExtractor<Void> {
        private final RowMapper<T> mapper;

        private Extractor(RowMapper<T> mapper) {
            this.mapper = mapper;
        }

        @Override
        public Void extractData(ResultSet rs) throws SQLException, DataAccessException {
            try {
                int rowNum = 0;
                while(rs.next()) {
                    if(cancelled.get()) return null;
                    Object obj = mapper.mapRow(rs, rowNum++);
                    putData(obj);
                }
                if(cancelled.get()) return null;
                putData(endOfDataObject);
                return null;
            } catch(Throwable e) { // we do not believe to JDBC drivers' error reporting
                if(e instanceof SQLException) throw (SQLException) e;
                throw new SQLException(e);
            }
        }
    }

    private class SubmitFun implements Function<SqlParameterSource, Future<?>> {
        @Override
        @SuppressWarnings("unchecked") // workaround to
        // keep accessors fully typed without additional
        // generic arguments to iterator itself
        public Future<?> apply(SqlParameterSource params) {
            DataSourceAccessor ungeneric = sources;
            NamedParameterJdbcOperations npjo = ungeneric.get(params);
            Worker worker = new Worker(npjo, params);
            return executor.submit(worker);
        }
    }
}
TOP

Related Classes of com.alexkasko.springjdbc.parallel.ParallelQueriesIterator$SubmitFun

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.