package org.apache.blur.thrift;
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.blur.thirdparty.thrift_0_9_0.async.AsyncMethodCallback;
import org.apache.blur.thirdparty.thrift_0_9_0.async.TAsyncClient;
import org.apache.blur.thirdparty.thrift_0_9_0.async.TAsyncClientManager;
import org.apache.blur.thirdparty.thrift_0_9_0.protocol.TBinaryProtocol;
import org.apache.blur.thirdparty.thrift_0_9_0.protocol.TProtocolFactory;
import org.apache.blur.thirdparty.thrift_0_9_0.transport.TNonblockingSocket;
import org.apache.blur.thirdparty.thrift_0_9_0.transport.TNonblockingTransport;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
 * Pool of Thrift {@link TAsyncClient} instances, queued per {@link Connection}
 * and limited per host by a configurable maximum number of live clients.
 *
 * <p>
 * Callers obtain a dynamic proxy through {@link #getClient(Class, String)}.
 * Every call made on that proxy borrows a client from the pool (creating one
 * if the per-host limit allows it), swaps the caller's
 * {@link AsyncMethodCallback} (always the last argument) for a wrapper, and
 * dispatches the call. The wrapper hands the client back to the pool on
 * completion, or discards it and closes its transport on error.
 */
public class AsyncClientPool {

  public static final Log LOG = LogFactory.getLog(AsyncClientPool.class);

  public static final int DEFAULT_MAX_CONNECTIONS_PER_HOST = 5;
  public static final int DEFAULT_CONNECTION_TIMEOUT = 60000;

  private final int _maxConnectionsPerHost;
  private final long _timeout;
  // Milliseconds to wait on the idle queue per iteration while the host is at
  // its connection limit.
  private final long _pollTime = 5;
  // Number of live (borrowed + idle) clients per host name.
  private final Map<String, AtomicInteger> _numberOfConnections = new ConcurrentHashMap<String, AtomicInteger>();
  // Idle clients, ready to be borrowed, per connection.
  private final Map<Connection, BlockingQueue<TAsyncClient>> _clientMap = new ConcurrentHashMap<Connection, BlockingQueue<TAsyncClient>>();
  // AsyncIface class name -> constructor of the matching AsyncClient class.
  private final Map<String, Constructor<?>> _constructorCache = new ConcurrentHashMap<String, Constructor<?>>();
  private final TProtocolFactory _protocolFactory;
  private final TAsyncClientManager _clientManager;
  // Every transport opened by this pool that has not been closed yet.
  private final Collection<TNonblockingTransport> _transports = new LinkedBlockingQueue<TNonblockingTransport>();
  // Reflective handle on TAsyncClient's "___transport" field, used to reach
  // the transport of a client that has to be discarded.
  private final Field _transportField;
  private final Random random = new Random();

  /**
   * Creates a pool with {@link #DEFAULT_MAX_CONNECTIONS_PER_HOST} and
   * {@link #DEFAULT_CONNECTION_TIMEOUT}.
   *
   * @throws IOException
   *           if the async client manager cannot be created.
   */
  public AsyncClientPool() throws IOException {
    this(DEFAULT_MAX_CONNECTIONS_PER_HOST, DEFAULT_CONNECTION_TIMEOUT);
  }

  /**
   * Creates a pool.
   *
   * @param maxConnectionsPerHost
   *          the maximum number of live clients per host.
   * @param connectionTimeout
   *          the timeout passed to {@link TAsyncClient#setTimeout(long)} for
   *          every client created by this pool.
   * @throws IOException
   *           if the async client manager cannot be created.
   */
  public AsyncClientPool(int maxConnectionsPerHost, int connectionTimeout) throws IOException {
    _clientManager = new TAsyncClientManager();
    _protocolFactory = new TBinaryProtocol.Factory();
    _maxConnectionsPerHost = maxConnectionsPerHost;
    _timeout = connectionTimeout;
    Field transportField;
    try {
      transportField = TAsyncClient.class.getDeclaredField("___transport");
      transportField.setAccessible(true);
    } catch (SecurityException e) {
      throw new RuntimeException(e);
    } catch (NoSuchFieldException e) {
      throw new RuntimeException(e);
    }
    _transportField = transportField;
  }

  /**
   * Stops the client manager and closes every transport this pool has opened
   * and not yet closed.
   */
  public void close() {
    _clientManager.stop();
    for (TNonblockingTransport transport : _transports) {
      transport.close();
    }
  }

  /**
   * Gets a client instance that implements the AsyncIface interface that
   * connects to the given connection string. One of the connections described
   * by the connection string is picked at random when this method is called
   * and is used for every call made through the returned instance.
   *
   * @param <T>
   *          the AsyncIface type.
   * @param asyncIfaceClass
   *          the AsyncIface interface to pool.
   * @param connectionStr
   *          the connection string.
   * @return the client instance.
   */
  @SuppressWarnings("unchecked")
  public <T> T getClient(final Class<T> asyncIfaceClass, final String connectionStr) {
    List<Connection> connections = BlurClientManager.getConnections(connectionStr);
    Collections.shuffle(connections, random);
    // randomness ftw
    final Connection connection = connections.get(0);
    return (T) Proxy.newProxyInstance(asyncIfaceClass.getClassLoader(), new Class[] { asyncIfaceClass }, new InvocationHandler() {
      @Override
      public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // toString/hashCode/equals are routed to the handler as well; they
        // carry no callback and must never borrow a pooled client.
        if (method.getDeclaringClass() == Object.class) {
          return invokeObjectMethod(proxy, method, args, asyncIfaceClass, connection);
        }
        return execute(new AsyncCall(asyncIfaceClass, method, args, connection));
      }
    });
  }

  /**
   * Answers the java.lang.Object methods a proxy forwards to its handler
   * (equals, hashCode and toString) using identity semantics.
   */
  private static Object invokeObjectMethod(Object proxy, Method method, Object[] args, Class<?> asyncIfaceClass, Connection connection) {
    String name = method.getName();
    if ("equals".equals(name)) {
      return Boolean.valueOf(proxy == args[0]);
    }
    if ("hashCode".equals(name)) {
      return Integer.valueOf(System.identityHashCode(proxy));
    }
    return asyncIfaceClass.getName() + "@[" + connection + "]";
  }

  /**
   * Borrows a client, replaces the caller's callback with a pool-aware wrapper
   * and dispatches the call on the borrowed client.
   */
  private Object execute(AsyncCall call) throws Exception {
    AsyncMethodCallback<?> realCallback = getRealAsyncMethodCallback(call._args);
    TAsyncClient client = newClient(call._clazz, call._connection);
    AsyncMethodCallback<?> retryingCallback = wrapCallback(realCallback, client, call._connection);
    resetArgs(call._args, retryingCallback);
    boolean dispatched = false;
    try {
      Object result = call._method.invoke(client, call._args);
      dispatched = true;
      return result;
    } finally {
      if (!dispatched) {
        // The wrapper callback is not expected to fire for a call that failed
        // to dispatch, so nothing else would release this client. Its state
        // is unknown at this point, so it is discarded instead of re-pooled.
        discardClient(call._connection, client);
      }
    }
  }

  /**
   * Returns the idle-client queue for the given connection, creating it on
   * first use.
   */
  private synchronized BlockingQueue<TAsyncClient> getQueue(Connection connection) {
    BlockingQueue<TAsyncClient> blockingQueue = _clientMap.get(connection);
    if (blockingQueue == null) {
      blockingQueue = new LinkedBlockingQueue<TAsyncClient>();
      _clientMap.put(connection, blockingQueue);
    }
    return blockingQueue;
  }

  /**
   * Puts a healthy client back on the idle queue; a client that reports an
   * error is discarded and its transport closed.
   */
  private void returnClient(Connection connection, TAsyncClient client) throws InterruptedException {
    if (!client.hasError()) {
      getQueue(connection).put(client);
    } else {
      discardClient(connection, client);
    }
  }

  /**
   * Frees the per-host slot held by the client and closes its transport.
   */
  private void discardClient(Connection connection, TAsyncClient client) {
    releaseConnectionSlot(connection);
    closeAndRemoveTransport(client);
  }

  /**
   * Decrements the live-client counter of the connection's host, if one
   * exists.
   */
  private void releaseConnectionSlot(Connection connection) {
    AtomicInteger counter = _numberOfConnections.get(connection.getHost());
    if (counter != null) {
      counter.decrementAndGet();
    }
  }

  @SuppressWarnings({ "rawtypes", "unchecked" })
  private AsyncMethodCallback<?> wrapCallback(AsyncMethodCallback<?> realCallback, TAsyncClient client, Connection connectionStr) {
    return new ClientPoolAsyncMethodCallback(realCallback, client, this, connectionStr);
  }

  /**
   * Replaces the last argument (the callback slot) with the given callback.
   */
  private void resetArgs(Object[] args, AsyncMethodCallback<?> callback) {
    args[args.length - 1] = callback;
  }

  /**
   * Extracts the caller's callback, which has to be the last argument.
   *
   * @throws IllegalArgumentException
   *           if there are no arguments or the last one is not an
   *           {@link AsyncMethodCallback}.
   */
  private AsyncMethodCallback<?> getRealAsyncMethodCallback(Object[] args) {
    if (args == null || args.length == 0 || !(args[args.length - 1] instanceof AsyncMethodCallback)) {
      throw new IllegalArgumentException("The last argument of an async call must be a non-null AsyncMethodCallback");
    }
    return (AsyncMethodCallback<?>) args[args.length - 1];
  }

  /**
   * Borrows an idle client for the connection or creates a new one. While the
   * host is at its connection limit this waits, polling the idle queue, until
   * a client is returned or a slot is freed.
   */
  private TAsyncClient newClient(Class<?> c, Connection connection) throws InterruptedException {
    BlockingQueue<TAsyncClient> blockingQueue = getQueue(connection);
    TAsyncClient client = blockingQueue.poll();
    if (client != null) {
      return client;
    }

    AtomicInteger counter;
    synchronized (_numberOfConnections) {
      counter = _numberOfConnections.get(connection.getHost());
      if (counter == null) {
        counter = new AtomicInteger();
        _numberOfConnections.put(connection.getHost(), counter);
      }
    }

    // Only one thread per host gets past this point at a time, so the limit
    // check and the increment below cannot interleave. Decrements do not take
    // this monitor, so a waiting thread still observes freed slots.
    synchronized (counter) {
      int numOfConnections = counter.get();
      while (numOfConnections >= _maxConnectionsPerHost) {
        client = blockingQueue.poll(_pollTime, TimeUnit.MILLISECONDS);
        if (client != null) {
          return client;
        }
        if (LOG.isDebugEnabled()) {
          LOG.debug("Waiting for client number of connection [" + numOfConnections + "], max connection per host [" + _maxConnectionsPerHost + "]");
        }
        numOfConnections = counter.get();
      }
      LOG.info("Creating a new client for [" + connection + "]");
      Constructor<?> constructor = findClientConstructor(c);
      TNonblockingSocket transport = null;
      boolean created = false;
      try {
        transport = newTransport(connection);
        client = (TAsyncClient) constructor.newInstance(new Object[] { _protocolFactory, _clientManager, transport });
        client.setTimeout(_timeout);
        // Track the transport so close() can release it later.
        _transports.add(transport);
        counter.incrementAndGet();
        created = true;
        return client;
      } catch (Exception e) {
        throw new RuntimeException(e);
      } finally {
        if (!created && transport != null) {
          // The socket was opened but no client took ownership of it.
          transport.close();
        }
      }
    }
  }

  /**
   * Looks up (and caches) the three-argument constructor of the AsyncClient
   * class that sits next to the given AsyncIface class.
   */
  private Constructor<?> findClientConstructor(Class<?> c) {
    String name = c.getName();
    Constructor<?> constructor = _constructorCache.get(name);
    if (constructor == null) {
      String clientClassName = name.replace("$AsyncIface", "$AsyncClient");
      try {
        Class<?> clazz = Class.forName(clientClassName);
        constructor = clazz.getConstructor(new Class[] { TProtocolFactory.class, TAsyncClientManager.class, TNonblockingTransport.class });
        _constructorCache.put(name, constructor);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
    return constructor;
  }

  private TNonblockingSocket newTransport(Connection connection) throws IOException {
    return new TNonblockingSocket(connection.getHost(), connection.getPort());
  }

  /**
   * Callback wrapper that forwards to the caller's callback and afterwards
   * returns the client to the pool (on completion) or discards it (on error).
   */
  private static class ClientPoolAsyncMethodCallback<T> implements AsyncMethodCallback<T> {

    private final AsyncMethodCallback<T> _realCallback;
    private final TAsyncClient _client;
    private final AsyncClientPool _pool;
    private final Connection _connection;

    public ClientPoolAsyncMethodCallback(AsyncMethodCallback<T> realCallback, TAsyncClient client, AsyncClientPool pool, Connection connection) {
      _realCallback = realCallback;
      _client = client;
      _pool = pool;
      _connection = connection;
    }

    @Override
    public void onComplete(T response) {
      try {
        _realCallback.onComplete(response);
      } finally {
        // Hand the client back even when the caller's callback throws.
        try {
          _pool.returnClient(_connection, _client);
        } catch (InterruptedException e) {
          // The client never reached the idle queue; discard it so its slot
          // and socket are not lost, and keep the interrupt visible.
          _pool.discardClient(_connection, _client);
          Thread.currentThread().interrupt();
          throw new RuntimeException(e);
        }
      }
    }

    @Override
    public void onError(Exception exception) {
      // Free the slot before notifying the caller, so that a callback which
      // issues another call right away does not wait on this dead client.
      _pool.releaseConnectionSlot(_connection);
      try {
        _realCallback.onError(exception);
      } finally {
        _pool.closeAndRemoveTransport(_client);
      }
    }
  }

  /**
   * Everything needed to dispatch one proxied call.
   */
  private static class AsyncCall {
    final Class<?> _clazz;
    final Method _method;
    final Object[] _args;
    final Connection _connection;

    public AsyncCall(Class<?> clazz, Method method, Object[] args, Connection connection) {
      _clazz = clazz;
      _method = method;
      _args = args;
      _connection = connection;
    }
  }

  /**
   * Closes the client's transport and stops tracking it.
   */
  private void closeAndRemoveTransport(TAsyncClient client) {
    try {
      TNonblockingTransport transport = (TNonblockingTransport) _transportField.get(client);
      if (transport == null) {
        return;
      }
      _transports.remove(transport);
      transport.close();
    } catch (IllegalArgumentException e) {
      throw new RuntimeException(e);
    } catch (IllegalAccessException e) {
      throw new RuntimeException(e);
    }
  }
}