/*
* Copyright 2013, The Sporting Exchange Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.betfair.cougar.client.socket;
import com.betfair.cougar.client.socket.resolver.NetworkAddressResolver;
import com.betfair.cougar.netutil.nio.ClientHandshake;
import com.betfair.cougar.netutil.nio.NioConfig;
import com.betfair.cougar.netutil.nio.NioLogger;
import com.betfair.cougar.netutil.nio.NioUtils;
import com.betfair.cougar.netutil.nio.message.ProtocolMessage;
import com.betfair.cougar.util.JMXReportingThreadPoolExecutor;
import org.apache.mina.common.*;
import org.apache.mina.transport.socket.nio.SocketConnector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jmx.export.annotation.ManagedAttribute;
import org.springframework.jmx.export.annotation.ManagedResource;
import java.net.SocketAddress;
import java.util.*;
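/**
 * Creates, tracks and hands out MINA {@link IoSession}s to a set of server
 * endpoints. Established sessions are handed out in rotation by
 * {@link #getSession()}; when a tracked session is closed with reconnection
 * requested, a background task on the reconnect executor keeps trying to
 * re-establish it.
 */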
@ManagedResource
public class IoSessionFactory {
// SLF4J logger used for general connection life-cycle messages.
private static final Logger log = LoggerFactory.getLogger(IoSessionFactory.class);
// Session-aware logger used for per-session messages.
private final NioLogger logger;
// Timeout handed to ClientHandshake.await() when waiting for the handshake response.
private final int handshakeResponseTimeout;
// Base delay (passed to Thread.sleep, so milliseconds) between reconnect attempts.
// Volatile because it is changed through the setter while reconnect threads read it.
private volatile int reconnectInterval;
private final SocketConnector socketConnector;
// Guards the session map, the pending-connection map and the rotation counter.
private final Object lock = new Object();
// Rotation position used by getSession(); only modified while holding the lock.
private volatile int counter = 0;
// Maintains a list of all endpoints to which connections are established
private final Map<SocketAddress, IoSession> sessions = new TreeMap<SocketAddress, IoSession>(new AddressComparator());
private final IoHandler ioHandler;
private final IoFutureListener sessionClosedListener;
private final NioConfig nioConfig;
// Cleared by stop() so that running reconnect tasks give up.
private volatile boolean keepRunning = false;
// Executor on which reconnect tasks are run.
private final JMXReportingThreadPoolExecutor reconnectExecutor;
private final String hosts;
private final SessionRecycler sessionRecycler;
// Maintains a list of endpoints to which connections are being established
private final Map<SocketAddress, ReconnectTask> pendingConnections = new HashMap<SocketAddress, ReconnectTask>();
public IoSessionFactory(NioLogger logger,
String hosts,
JMXReportingThreadPoolExecutor executorService,
JMXReportingThreadPoolExecutor reconnectExecutor,
NioConfig config,
IoHandler ioHandler,
IoFutureListener sessionClosedListener,
int reconnectInterval,
int handshakeResponseTimeout,
long sessionRecycleInterval,
NetworkAddressResolver addressResolver) {
this.logger = logger;
this.reconnectInterval = reconnectInterval;
this.handshakeResponseTimeout = handshakeResponseTimeout;
this.hosts = hosts;
this.nioConfig = config;
this.socketConnector = new SocketConnector(executorService.getCorePoolSize(), executorService);
this.socketConnector.setWorkerTimeout(config.getWorkerTimeout());
this.ioHandler = ioHandler;
this.sessionClosedListener = sessionClosedListener;
this.keepRunning = false;
this.reconnectExecutor = reconnectExecutor;
sessionRecycler = new SessionRecycler(this, addressResolver, hosts, sessionRecycleInterval);
}
/**
 * @return {@code true} if at least one session is currently tracked as established
 */
public boolean isConnected() {
    final boolean anyEstablished;
    synchronized (lock) {
        anyEstablished = sessions.size() > 0;
    }
    return anyEstablished;
}
/**
 * Collects the server addresses this factory currently knows about: those with
 * an established session plus those for which a connection attempt is pending.
 *
 * @return a new, caller-owned set of socket addresses
 */
public Set<SocketAddress> getCurrentSessionAddresses() {
    synchronized (lock) {
        // Copy both key sets under the lock so the snapshot is consistent.
        final Set<SocketAddress> known = new HashSet<SocketAddress>(sessions.keySet());
        known.addAll(pendingConnections.keySet());
        return known;
    }
}
/**
 * Describes every established session.
 *
 * @return a new map from session id to a one-line description holding the
 *         session id, remote host, connected flag and closing flag
 */
public Map<String, String> getConnectedStatus() {
    // Snapshot under the lock; the sessions themselves are queried outside it.
    final IoSession[] snapshot;
    synchronized (lock) {
        snapshot = sessions.values().toArray(new IoSession[sessions.size()]);
    }

    final HashMap<String, String> statusById = new HashMap<String, String>();
    for (int i = 0; i < snapshot.length; i++) {
        final IoSession current = snapshot[i];
        final String id = NioUtils.getSessionId(current);
        final String description = "SessionId=" + id + ","
                + "remoteHost=" + current.getRemoteAddress() + ","
                + "connected=" + current.isConnected() + ","
                + "closing=" + current.isClosing() + ","
                + '\n';
        statusById.put(id, description);
    }
    return statusById;
}
/**
 * Marks the factory as running and then initialises the session recycler.
 * The flag is raised first so that reconnect tasks submitted during
 * initialisation do not exit immediately.
 */
public void start() {
    keepRunning = true;
    this.sessionRecycler.initialise();
}
/**
 * Stops the factory: reconnect tasks are told to give up and every session
 * that is established at the time of the call is closed.
 */
public void stop() {
    // Clearing the flag makes running reconnect loops exit
    this.keepRunning = false;

    // Take the snapshot under the lock, close outside of it
    final IoSession[] open;
    synchronized (lock) {
        open = sessions.values().toArray(new IoSession[sessions.size()]);
    }
    for (int i = 0; i < open.length; i++) {
        close(open[i]);
    }
}
/**
 * Hands out established sessions in rotation, skipping any that are not
 * currently available (see {@link #isAvailable(IoSession)}).
 *
 * @return the next available session, or {@code null} when there are no
 *         sessions or none of them is available
 */
public IoSession getSession() {
    synchronized (lock) {
        final int total = sessions.size();
        if (total == 0) {
            return null;
        }
        // Values come back in the same order as the keys of the sorted map.
        final IoSession[] candidates = sessions.values().toArray(new IoSession[total]);
        // Try each slot at most once, advancing the shared rotation counter as we go.
        for (int attempt = 0; attempt < total; attempt++) {
            counter++;
            final IoSession candidate = candidates[Math.abs(counter % total)];
            if (isAvailable(candidate)) {
                return candidate;
            }
        }
        return null;
    }
}
/**
 * Starts connecting to the given address in the background. Does nothing if
 * a connection attempt to that address is already pending.
 *
 * @param endpoint address of the server to connect to
 */
public void openSession(SocketAddress endpoint) {
    synchronized (lock) {
        if (pendingConnections.containsKey(endpoint)) {
            // An attempt for this address is already under way
            return;
        }
        final ReconnectTask attempt = new ReconnectTask(endpoint);
        pendingConnections.put(endpoint, attempt);
        reconnectExecutor.submit(attempt);
    }
}
/**
 * Shuts down whatever this factory holds for the given endpoint. If a
 * connection attempt is pending, that attempt is asked to stop; otherwise an
 * established session, if any, is closed.
 *
 * @param endpoint  address of the server
 * @param reconnect whether to reconnect after closing an established session;
 *                  ignored when only a pending attempt is stopped
 */
public void closeSession(SocketAddress endpoint, boolean reconnect) {
    synchronized (lock) {
        if (!pendingConnections.containsKey(endpoint)) {
            // No attempt in flight: close the established session, if there is one
            final IoSession established = sessions.get(endpoint);
            if (established != null) {
                close(established, reconnect);
            }
            return;
        }
        // A connection attempt is in flight: ask it to give up
        final ReconnectTask pending = pendingConnections.get(endpoint);
        if (pending != null) {
            pending.stop();
        }
    }
}
/**
 * Decides whether a session may be handed out by {@link #getSession()}.
 *
 * @param session the session to check
 * @return {@code true} only if the session is connected, is not closing and
 *         carries neither the SUSPEND nor the DISCONNECT marker attribute
 */
boolean isAvailable(IoSession session) {
    if (!session.isConnected()) {
        return false; // not connected
    }
    if (session.isClosing()) {
        return false; // close has been initiated
    }
    if (session.containsAttribute(ProtocolMessage.ProtocolMessageType.SUSPEND.name())) {
        return false; // suspend message has been received
    }
    // available unless a disconnect message has been received
    return !session.containsAttribute(ProtocolMessage.ProtocolMessageType.DISCONNECT.name());
}
// Registered on the close future of every established session; routes the
// close event back into this factory so the session is untracked and a
// reconnect is scheduled.
private final IoFutureListener serverSideCloseListener = new IoFutureListener() {
    @Override
    public void operationComplete(IoFuture future) {
        final IoSession closedSession = future.getSession();
        IoSessionFactory.this.close(closedSession);
    }
};
/**
 * Closes the given session and schedules a reconnect to its remote address.
 * Equivalent to {@code close(aSession, true)}.
 *
 * @param aSession the session to close; {@code null} is ignored
 */
public void close(final IoSession aSession) {
    final boolean reconnect = true;
    close(aSession, reconnect);
}
/**
 * Stops tracking the session registered for the given session's remote
 * address and, if one was tracked, closes the given session unless it is
 * already closing. When nothing was tracked for that address the call does
 * nothing.
 *
 * NOTE(review): the entry is removed by remote address, not by session
 * identity - confirm that a late close event for an old session cannot evict
 * a newer session to the same address.
 *
 * @param aSession  the session to close; {@code null} is ignored
 * @param reconnect whether to start reconnecting to the same address afterwards
 */
public void close(final IoSession aSession, boolean reconnect) {
    if (aSession == null) {
        return;
    }

    final SocketAddress remoteAddress = aSession.getRemoteAddress();
    final boolean wasTracked;
    synchronized (lock) {
        wasTracked = sessions.remove(remoteAddress) != null;
    }
    if (!wasTracked) {
        return;
    }

    try {
        if (!aSession.isClosing()) {
            logger.log(NioLogger.LoggingLevel.SESSION, aSession, "IoSessionFactory - Closing session");
            aSession.close();
        }
    } finally {
        // Reconnect even if closing threw; openSession is a no-op when an attempt is already pending
        if (reconnect) {
            openSession(remoteAddress);
        }
    }
}
/**
 * Makes a single attempt to connect to the endpoint and complete the
 * handshake. Blocks until the connection attempt finishes and then until the
 * handshake completes or times out.
 *
 * @param endpoint address of the server to connect to
 * @return the established session with the close listeners registered, or
 *         {@code null} if connecting or the handshake failed
 */
public IoSession connect(final SocketAddress endpoint) {
    ConnectFuture attempt;
    try {
        attempt = socketConnector.connect(endpoint, this.ioHandler, this.nioConfig.configureSocketSessionConfig());
    } catch (Exception e) {
        log.info("Error connecting to " + endpoint, e);
        return null;
    }
    if (attempt == null) {
        return null;
    }

    // Wait for the connection attempt to finish
    attempt.join();
    if (!attempt.isConnected()) {
        log.info("Failed to connect to " + endpoint);
        return null;
    }

    log.info("Connected to " + endpoint);
    final IoSession session = attempt.getSession();
    if (!handshake(session)) {
        log.info("Handshake failed for " + endpoint);
        logger.log(NioLogger.LoggingLevel.SESSION, session, "Handshake failed for %s", endpoint);
        session.close();
        return null;
    }

    // Be told when the session goes away: first this factory, then the external listener
    final CloseFuture onClose = session.getCloseFuture();
    onClose.addListener(this.serverSideCloseListener);
    onClose.addListener(this.sessionClosedListener);
    return session;
}
/**
 * Waits for the client handshake stored on the session to complete, using
 * the configured response timeout, and then drops the handshake attribute
 * from the session.
 *
 * @param session the freshly connected session
 * @return whether the handshake reported success
 */
private boolean handshake(IoSession session) {
    final Object attribute = session.getAttribute(ClientHandshake.HANDSHAKE);
    final ClientHandshake pendingHandshake = (ClientHandshake) attribute;
    pendingHandshake.await(handshakeResponseTimeout);
    // The attribute is not needed once the wait is over
    session.removeAttribute(ClientHandshake.HANDSHAKE);
    return pendingHandshake.successful();
}
// ############################################
/**
 * Background task that repeatedly tries to connect to one endpoint until it
 * succeeds, the factory is stopped, or the task itself is stopped. On
 * success the session is registered as established; in every other outcome
 * the endpoint is removed from the pending connections so that a later
 * {@code openSession} call can start a fresh attempt.
 */
private class ReconnectTask implements Runnable {
    private final SocketAddress socketAddress;
    // Set by stop() from another thread while run() polls it, hence volatile.
    private volatile boolean stop;

    private ReconnectTask(SocketAddress socketAddress) {
        this.socketAddress = socketAddress;
        this.stop = false;
    }

    @Override
    public void run() {
        boolean registered = false;
        try {
            long attempt = 1;
            while (keepRunning && !stop) {
                final IoSession session = IoSessionFactory.this.connect(socketAddress);
                if (session != null) {
                    // Move the endpoint from "pending" to "established" atomically
                    synchronized (lock) {
                        sessions.put(socketAddress, session);
                        pendingConnections.remove(socketAddress);
                    }
                    registered = true;
                    return;
                }
                try {
                    // based on geometric series sum to plateau 10 times initial value
                    Thread.sleep((long) (reconnectInterval * (1.0 - Math.pow(0.9, attempt)) / 0.1));
                    attempt++;
                } catch (InterruptedException ignored) {
                    // Deliberately ignored: the loop is ended through keepRunning / stop,
                    // not through interruption.
                }
            }
        } finally {
            // Also reached when connect() throws unexpectedly: without this the
            // endpoint would stay "pending" forever and could never be reopened.
            if (!registered) {
                synchronized (lock) {
                    pendingConnections.remove(socketAddress);
                }
            }
        }
    }

    // Stop attempting to connect
    public void stop() {
        this.stop = true;
    }
}
/**
 * @return the base delay between reconnect attempts
 */
@ManagedAttribute
public int getReconnectInterval() {
    return this.reconnectInterval;
}
/**
 * Replaces the base delay between reconnect attempts.
 *
 * @param reconnectInterval the new base delay
 */
/*package*/ void setReconnectInterval(int reconnectInterval) {
    final int newInterval = reconnectInterval;
    this.reconnectInterval = newInterval;
}
/**
 * @return the host specification this factory was constructed with
 */
@ManagedAttribute
public String getHosts() {
    return this.hosts;
}
/**
 * @return the session recycler created by this factory
 */
/*package*/ SessionRecycler getSessionRecycler() {
    return sessionRecycler;
}
/**
* Simple comparator used for sorting the list the resolved
* server socket addresses
*/
private class AddressComparator implements Comparator<SocketAddress> {
@Override
public int compare(SocketAddress o1, SocketAddress o2) {
return o2.hashCode() - o1.hashCode();
}
}
}