/*
* Copyright (c) 2008-2013, Hazelcast, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hazelcast.nio;
import com.hazelcast.cluster.BindOperation;
import com.hazelcast.config.SSLConfig;
import com.hazelcast.config.SocketInterceptorConfig;
import com.hazelcast.logging.ILogger;
import com.hazelcast.nio.serialization.Data;
import com.hazelcast.nio.serialization.SerializationContext;
import com.hazelcast.nio.ssl.BasicSSLContextFactory;
import com.hazelcast.nio.ssl.SSLContextFactory;
import com.hazelcast.nio.ssl.SSLSocketChannelWrapper;
import com.hazelcast.util.ConcurrencyUtil;
import com.hazelcast.util.ConstructorFunction;
import com.hazelcast.util.executor.StripedRunnable;
import java.io.IOException;
import java.net.Socket;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
public class TcpIpConnectionManager implements ConnectionManager {
// Logger obtained from the IOService in the constructor.
private final ILogger logger;
// Socket buffer sizes: the IOService setting multiplied by IOService.KILO_BYTE (see constructor).
final int socketReceiveBufferSize;
final int socketSendBufferSize;
// Socket options read from the IOService and applied by initSocket(Socket).
private final int socketLingerSeconds;
private final boolean socketKeepAlive;
private final boolean socketNoDelay;
// Connections registered per remote endpoint (filled by registerConnectionEndpoint, emptied by destroyConnection/stop).
private final ConcurrentMap<Address, Connection> connectionsMap = new ConcurrentHashMap<Address, Connection>(100);
// Per-endpoint monitors, created lazily by getConnectionMonitor.
private final ConcurrentMap<Address, ConnectionMonitor> monitors = new ConcurrentHashMap<Address, ConnectionMonitor>(100);
// Addresses for which getOrConnect has already submitted a SocketConnector; prevents duplicate connect attempts.
private final Set<Address> connectionsInProgress = Collections.newSetFromMap(new ConcurrentHashMap<Address, Boolean>());
// Listeners notified when a connection is registered or removed.
private final Set<ConnectionListener> connectionListeners = new CopyOnWriteArraySet<ConnectionListener>();
// Wrapped channels that have not yet been turned into a TcpIpConnection (added in wrapSocketChannel, removed in assignSocketChannel).
private final Set<SocketChannelWrapper> acceptedSockets = Collections.newSetFromMap(new ConcurrentHashMap<SocketChannelWrapper, Boolean>());
// Every TcpIpConnection created by assignSocketChannel until it is destroyed.
private final Set<TcpIpConnection> activeConnections = Collections.newSetFromMap(new ConcurrentHashMap<TcpIpConnection, Boolean>());
// Counter bumped by incrementTextConnections().
private final AtomicInteger allTextConnections = new AtomicInteger();
// Source of the ids handed to new TcpIpConnection instances.
private final AtomicInteger connectionIdGen = new AtomicInteger();
// True between start() and stop().
private volatile boolean live = false;
final IOService ioService;
// May be null; start() and shutdown() both check for that.
private final ServerSocketChannel serverSocketChannel;
// Number of in/out selector pairs; both arrays below have this length.
private final int selectorThreadCount;
private final IOSelector[] inSelectors;
private final IOSelector[] outSelectors;
// Round-robin counter used by nextSelectorIndex().
private final AtomicInteger nextSelectorIndex = new AtomicInteger();
// Null unless an enabled SocketInterceptorConfig yields a MemberSocketInterceptor (see constructor).
private final MemberSocketInterceptor memberSocketInterceptor;
// Either the SSL or the default wrapper factory, chosen in the constructor.
private final SocketChannelWrapperFactory socketChannelWrapperFactory;
// Size of the configured outbound port collection; 0 when none was configured.
private final int outboundPortCount;
private final LinkedList<Integer> outboundPorts = new LinkedList<Integer>(); // accessed only in synchronized block
private final SerializationContext serializationContext;
private volatile Thread socketAcceptorThread; // accessed only in synchronized block
/**
 * Builds a connection manager from the settings exposed by the given {@link IOService}.
 * <p>
 * Reads the socket options, the selector thread count and the outbound ports, picks the
 * SSL or the default channel wrapper factory, and resolves the optional member socket
 * interceptor. Nothing is started here; see {@link #start()}.
 *
 * @param ioService           source of configuration, logging and serialization services
 * @param serverSocketChannel channel handed to the socket acceptor on start; may be null
 */
public TcpIpConnectionManager(IOService ioService, ServerSocketChannel serverSocketChannel) {
    this.ioService = ioService;
    this.serverSocketChannel = serverSocketChannel;
    this.logger = ioService.getLogger(TcpIpConnectionManager.class.getName());

    // Socket options.
    this.socketReceiveBufferSize = ioService.getSocketReceiveBufferSize() * IOService.KILO_BYTE;
    this.socketSendBufferSize = ioService.getSocketSendBufferSize() * IOService.KILO_BYTE;
    this.socketLingerSeconds = ioService.getSocketLingerSeconds();
    this.socketKeepAlive = ioService.getSocketKeepAlive();
    this.socketNoDelay = ioService.getSocketNoDelay();

    // Selector slots are allocated now and populated in start().
    this.selectorThreadCount = ioService.getSelectorThreadCount();
    this.inSelectors = new IOSelector[this.selectorThreadCount];
    this.outSelectors = new IOSelector[this.selectorThreadCount];

    // Outbound ports: a null collection means "no restriction".
    final Collection<Integer> configuredPorts = ioService.getOutboundPorts();
    if (configuredPorts == null) {
        this.outboundPortCount = 0;
    } else {
        this.outboundPortCount = configuredPorts.size();
        this.outboundPorts.addAll(configuredPorts);
    }

    // Channel wrapper factory: SSL when configured and enabled, plain otherwise.
    final SSLConfig sslConfig = ioService.getSSLConfig();
    final boolean sslEnabled = sslConfig != null && sslConfig.isEnabled();
    if (!sslEnabled) {
        this.socketChannelWrapperFactory = new DefaultSocketChannelWrapperFactory();
    } else {
        this.socketChannelWrapperFactory = new SSLSocketChannelWrapperFactory(sslConfig);
        logger.info("SSL is enabled");
    }

    // Socket interceptor: use the configured instance, else instantiate the configured class.
    final SocketInterceptorConfig interceptorConfig = ioService.getSocketInterceptorConfig();
    SocketInterceptor candidate = null;
    if (interceptorConfig != null && interceptorConfig.isEnabled()) {
        candidate = (SocketInterceptor) interceptorConfig.getImplementation();
        if (candidate == null && interceptorConfig.getClassName() != null) {
            try {
                candidate = (SocketInterceptor) Class.forName(interceptorConfig.getClassName()).newInstance();
            } catch (Throwable e) {
                logger.severe("SocketInterceptor class cannot be instantiated!" + interceptorConfig.getClassName(), e);
            }
        }
        // Only member-side interceptors are usable here; anything else is rejected.
        if (candidate != null && !(candidate instanceof MemberSocketInterceptor)) {
            logger.severe( "SocketInterceptor must be instance of " + MemberSocketInterceptor.class.getName());
            candidate = null;
        }
    }
    this.memberSocketInterceptor = (MemberSocketInterceptor) candidate;
    if (this.memberSocketInterceptor != null) {
        logger.info("SocketInterceptor is enabled");
        this.memberSocketInterceptor.init(interceptorConfig.getProperties());
    }

    this.serializationContext = ioService.getSerializationContext();
}
/**
 * Returns the current size of the active-connection set.
 */
@Override
public int getActiveConnectionCount() {
    final int activeCount = activeConnections.size();
    return activeCount;
}
/**
 * Returns the current value of the text-connection counter
 * (see {@link #incrementTextConnections()}).
 */
public int getAllTextConnections() {
    final int textConnections = allTextConnections.get();
    return textConnections;
}
/**
 * Returns the number of entries in the endpoint-to-connection map.
 */
@Override
public int getConnectionCount() {
    final int registered = connectionsMap.size();
    return registered;
}
/**
 * Tells whether the SSL channel wrapper factory was selected in the constructor.
 *
 * @return true if channels are wrapped by a {@code SSLSocketChannelWrapperFactory}
 */
@Override
public boolean isSSLEnabled() {
    final boolean sslFactoryInUse = socketChannelWrapperFactory instanceof SSLSocketChannelWrapperFactory;
    return sslFactoryInUse;
}
/**
 * Adds one to the text-connection counter reported by {@link #getAllTextConnections()}.
 */
public void incrementTextConnections() {
    this.allTextConnections.incrementAndGet();
}
/**
 * Returns the serialization context obtained from the IOService at construction time.
 */
public SerializationContext getSerializationContext() {
    return this.serializationContext;
}
/**
 * Strategy for turning a raw {@link SocketChannel} into a {@link SocketChannelWrapper}.
 * Two implementations exist in this class: a default one and an SSL one.
 */
interface SocketChannelWrapperFactory {
/**
 * Wraps the given channel.
 *
 * @param socketChannel the channel to wrap
 * @param client        flag passed through to the wrapper; the default implementation ignores it
 * @return the wrapper around {@code socketChannel}
 * @throws Exception if the wrapper cannot be created
 */
SocketChannelWrapper wrapSocketChannel(SocketChannel socketChannel, boolean client) throws Exception;
}
static class DefaultSocketChannelWrapperFactory implements SocketChannelWrapperFactory {
public SocketChannelWrapper wrapSocketChannel(SocketChannel socketChannel, boolean client) throws Exception {
return new DefaultSocketChannelWrapper(socketChannel);
}
}
class SSLSocketChannelWrapperFactory implements SocketChannelWrapperFactory {
final SSLContextFactory sslContextFactory;
SSLSocketChannelWrapperFactory(SSLConfig sslConfig) {
if (CipherHelper.isSymmetricEncryptionEnabled(ioService)) {
throw new RuntimeException("SSL and SymmetricEncryption cannot be both enabled!");
}
SSLContextFactory sslContextFactoryObject = (SSLContextFactory) sslConfig.getFactoryImplementation();
try {
String factoryClassName = sslConfig.getFactoryClassName();
if (sslContextFactoryObject == null && factoryClassName != null) {
sslContextFactoryObject = (SSLContextFactory) Class.forName(factoryClassName).newInstance();
}
if (sslContextFactoryObject == null) {
sslContextFactoryObject = new BasicSSLContextFactory();
}
sslContextFactoryObject.init(sslConfig.getProperties());
} catch (Exception e) {
throw new RuntimeException(e);
}
sslContextFactory = sslContextFactoryObject;
}
public SocketChannelWrapper wrapSocketChannel(SocketChannel socketChannel, boolean client) throws Exception {
return new SSLSocketChannelWrapper(sslContextFactory.getSSLContext(), socketChannel, client);
}
}
/**
 * Returns the {@link IOService} this manager was constructed with.
 */
public IOService getIOHandler() {
    return this.ioService;
}
/**
 * Returns the member socket interceptor resolved in the constructor,
 * or null when none is configured or usable.
 */
public MemberSocketInterceptor getMemberSocketInterceptor() {
    return this.memberSocketInterceptor;
}
/**
 * Registers a listener to be notified when connections are added or removed.
 *
 * @param listener the listener to add to the listener set
 */
public void addConnectionListener(ConnectionListener listener) {
    this.connectionListeners.add(listener);
}
/**
 * Binds a connection to a remote endpoint.
 * <p>
 * A non-client connection whose requested local endpoint is not this node's address is
 * closed and rejected. Otherwise the endpoint is set on the connection, a bind request is
 * optionally sent back, and the connection is registered unless a different live
 * connection is already bound to the same endpoint.
 *
 * @param connection     the connection to bind
 * @param remoteEndPoint address of the remote side
 * @param localEndpoint  address the remote side believes it is talking to
 * @param replyBack      whether to send a bind request back over {@code connection}
 * @return true only if the connection was newly registered for {@code remoteEndPoint}
 */
public boolean bind(TcpIpConnection connection, Address remoteEndPoint, Address localEndpoint, final boolean replyBack) {
    if (logger.isFinestEnabled()) {
        log(Level.FINEST, "Binding " + connection + " to " + remoteEndPoint + ", replyBack is " + replyBack);
    }

    final Address thisAddress = ioService.getThisAddress();
    final boolean misdirected = !connection.isClient() && !thisAddress.equals(localEndpoint);
    if (misdirected) {
        log(Level.WARNING, "Wrong bind request from " + remoteEndPoint + "! This node is not requested endpoint: " + localEndpoint);
        connection.close();
        return false;
    }

    connection.setEndPoint(remoteEndPoint);
    if (replyBack) {
        sendBindRequest(connection, remoteEndPoint, false);
    }

    final Connection current = connectionsMap.get(remoteEndPoint);
    final boolean endpointTaken = current != null && current.live();
    if (!endpointTaken) {
        return registerConnectionEndpoint(connection, remoteEndPoint, thisAddress);
    }

    // A live connection already owns this endpoint: keep tracking the new one as active only.
    if (current != connection) {
        if (logger.isFinestEnabled()) {
            log(Level.FINEST, current + " is already bound to " + remoteEndPoint + ", new one is " + connection);
        }
        activeConnections.add(connection);
    }
    return false;
}
/**
 * Registers the connection under the remote endpoint and notifies listeners asynchronously.
 * Nothing is registered when the remote endpoint equals this node's own address.
 *
 * @param connection     connection to register
 * @param remoteEndPoint key under which the connection is stored
 * @param thisAddress    this node's address
 * @return true if the connection was stored, false for a self-connection
 */
private boolean registerConnectionEndpoint(final TcpIpConnection connection,
                                           final Address remoteEndPoint, Address thisAddress) {
    if (remoteEndPoint.equals(thisAddress)) {
        return false;
    }

    final boolean memberConnection = !connection.isClient();
    if (memberConnection) {
        connection.setMonitor(getConnectionMonitor(remoteEndPoint, true));
    }
    connectionsMap.put(remoteEndPoint, connection);
    connectionsInProgress.remove(remoteEndPoint);

    // Listener callbacks are striped by the endpoint's hash code.
    ioService.getEventService().executeEventCallback(new StripedRunnable() {
        @Override
        public int getKey() {
            return remoteEndPoint.hashCode();
        }

        @Override
        public void run() {
            for (ConnectionListener each : connectionListeners) {
                each.connectionAdded(connection);
            }
        }
    });
    return true;
}
/**
 * Sets the endpoint on the connection and writes a {@link BindOperation} packet to it.
 *
 * @param connection     connection to write the bind packet to
 * @param remoteEndPoint address of the remote side
 * @param replyBack      value carried in the bind operation
 */
void sendBindRequest(final TcpIpConnection connection, final Address remoteEndPoint, final boolean replyBack) {
    connection.setEndPoint(remoteEndPoint);
    // The bind packet has to be the first packet sent to the end point.
    final Data serializedBind = ioService.toData(
            new BindOperation(ioService.getThisAddress(), remoteEndPoint, replyBack));
    final Packet bindPacket = new Packet(serializedBind, serializationContext);
    bindPacket.setHeader(Packet.HEADER_OP);
    connection.write(bindPacket);
    // From here on anything else may be sent.
}
/**
 * Picks the selector slot for the next connection in round-robin fashion.
 * <p>
 * The sign bit of the counter is masked off instead of using {@code Math.abs}:
 * {@code Math.abs(Integer.MIN_VALUE)} is still negative, so after the counter wraps
 * around the old code could return a negative index.
 *
 * @return an index in the range {@code [0, selectorThreadCount)}
 */
private int nextSelectorIndex() {
    final int ticket = nextSelectorIndex.getAndIncrement() & Integer.MAX_VALUE;
    return ticket % selectorThreadCount;
}
/**
 * Wraps the channel with the configured factory and remembers the wrapper as an accepted socket.
 *
 * @param socketChannel channel to wrap
 * @param client        flag forwarded to the wrapper factory
 * @return the newly created wrapper
 * @throws Exception if the factory fails to wrap the channel
 */
SocketChannelWrapper wrapSocketChannel(SocketChannel socketChannel, boolean client) throws Exception {
    final SocketChannelWrapper wrapped = socketChannelWrapperFactory.wrapSocketChannel(socketChannel, client);
    acceptedSockets.add(wrapped);
    return wrapped;
}
/**
 * Creates a {@link TcpIpConnection} for the wrapped channel on the next selector pair.
 * <p>
 * The connection is added to the active set, the channel is dropped from the accepted
 * sockets, and the connection's read handler is registered.
 *
 * @param channel wrapped socket channel to attach
 * @return the new connection
 */
TcpIpConnection assignSocketChannel(SocketChannelWrapper channel) {
    final int slot = nextSelectorIndex();
    final IOSelector in = inSelectors[slot];
    final IOSelector out = outSelectors[slot];
    final int connectionId = connectionIdGen.incrementAndGet();
    final TcpIpConnection created = new TcpIpConnection(this, in, out, connectionId, channel);

    activeConnections.add(created);
    acceptedSockets.remove(channel);
    created.getReadHandler().register();

    log(Level.INFO, channel.socket().getLocalPort() + " accepted socket connection from " + channel.socket().getRemoteSocketAddress());
    return created;
}
/**
 * Records a failed connection attempt to the given address.
 * <p>
 * Clears the in-progress marker and informs the IOService; unless {@code silent},
 * the error is also reported to the address's connection monitor.
 *
 * @param address address that could not be connected to
 * @param t       cause of the failure
 * @param silent  when true the connection monitor is not informed
 */
void failedConnection(Address address, Throwable t, boolean silent) {
    connectionsInProgress.remove(address);
    ioService.onFailedConnection(address);
    if (silent) {
        return;
    }
    final ConnectionMonitor monitor = getConnectionMonitor(address, false);
    monitor.onError(t);
}
/**
 * Looks up the connection registered for the given address.
 *
 * @param address remote endpoint address
 * @return the registered connection, or null if there is none
 */
public Connection getConnection(Address address) {
    final Connection registered = connectionsMap.get(address);
    return registered;
}
/**
 * Same as {@link #getOrConnect(Address, boolean)} with {@code silent} set to false.
 *
 * @param address remote endpoint address
 * @return the registered connection, or null if there is none yet
 */
public Connection getOrConnect(Address address) {
    final boolean silent = false;
    return getOrConnect(address, silent);
}
/**
 * Returns the connection registered for the address, triggering an asynchronous connect
 * when there is none.
 * <p>
 * A connect is started only while the manager is live and only if no attempt for the
 * same address is already in progress. The method does not wait for the connect to finish.
 *
 * @param address remote endpoint address
 * @param silent  forwarded to the {@code SocketConnector}
 * @return the registered connection, or null if none is registered at the time of the call
 */
public Connection getOrConnect(final Address address, final boolean silent) {
    final Connection existing = connectionsMap.get(address);
    if (existing != null || !live) {
        return existing;
    }
    // Set.add doubles as the "first caller wins" check for this address.
    final boolean firstAttempt = connectionsInProgress.add(address);
    if (firstAttempt) {
        ioService.shouldConnectTo(address);
        ioService.executeAsync(new SocketConnector(this, address, silent));
    }
    return existing;
}
/**
 * Creates the {@link ConnectionMonitor} for an endpoint; passed to
 * {@code ConcurrencyUtil.getOrPutIfAbsent} by {@code getConnectionMonitor}.
 */
private final ConstructorFunction<Address, ConnectionMonitor> monitorConstructor =
        new ConstructorFunction<Address, ConnectionMonitor>() {
            public ConnectionMonitor createNew(Address monitoredEndpoint) {
                final TcpIpConnectionManager owner = TcpIpConnectionManager.this;
                return new ConnectionMonitor(owner, monitoredEndpoint);
            }
        };
/**
 * Returns the monitor for the endpoint, creating it on first use.
 *
 * @param endpoint endpoint whose monitor is requested
 * @param reset    when true, {@code reset()} is invoked on the monitor before it is returned
 * @return the monitor associated with {@code endpoint}
 */
private ConnectionMonitor getConnectionMonitor(Address endpoint, boolean reset) {
    final ConnectionMonitor found = ConcurrencyUtil.getOrPutIfAbsent(monitors, endpoint, monitorConstructor);
    if (!reset) {
        return found;
    }
    found.reset();
    return found;
}
/**
 * Removes the connection from this manager's bookkeeping and closes it if it is still live.
 * <p>
 * The connection is dropped from the active set and its endpoint from the in-progress
 * set. While the manager is live, the endpoint mapping is removed only if it still
 * points at this connection, and listeners are then notified asynchronously.
 *
 * @param connection connection to destroy; null is ignored
 */
@Override
public void destroyConnection(final Connection connection) {
    if (connection == null) {
        return;
    }
    if (logger.isFinestEnabled()) {
        log(Level.FINEST, "Destroying " + connection);
    }
    activeConnections.remove(connection);
    final Address endPoint = connection.getEndPoint();
    if (endPoint != null) {
        connectionsInProgress.remove(endPoint);
        // Atomic conditional remove: a separate get() followed by remove(key) could evict a
        // connection that another thread registered for this endpoint in between.
        // Note: ConcurrentMap.remove(key, value) matches by equals() rather than identity.
        if (live && connectionsMap.remove(endPoint, connection)) {
            ioService.getEventService().executeEventCallback(new StripedRunnable() {
                @Override
                public void run() {
                    for (ConnectionListener listener : connectionListeners) {
                        listener.connectionRemoved(connection);
                    }
                }

                @Override
                public int getKey() {
                    return endPoint.hashCode();
                }
            });
        }
    }
    if (connection.live()) {
        connection.close();
    }
}
/**
 * Applies the configured socket options to the given socket.
 * <p>
 * SO_LINGER is only set when a positive linger time is configured; keep-alive,
 * TCP no-delay and both buffer sizes are always set.
 *
 * @param socket socket to configure
 * @throws Exception if any of the socket options cannot be applied
 */
protected void initSocket(Socket socket) throws Exception {
    final boolean lingerConfigured = this.socketLingerSeconds > 0;
    if (lingerConfigured) {
        socket.setSoLinger(true, this.socketLingerSeconds);
    }
    socket.setKeepAlive(this.socketKeepAlive);
    socket.setTcpNoDelay(this.socketNoDelay);
    socket.setReceiveBufferSize(this.socketReceiveBufferSize);
    socket.setSendBufferSize(this.socketSendBufferSize);
}
/**
 * Starts the manager: marks it live, creates and starts one in/out selector pair per
 * slot, and, when a server socket channel is present, starts the acceptor thread.
 * Calling this while already live does nothing.
 */
public synchronized void start() {
    if (live) {
        return;
    }
    live = true;
    log(Level.FINEST, "Starting ConnectionManager and IO selectors.");

    final int slots = inSelectors.length;
    for (int slot = 0; slot < slots; slot++) {
        inSelectors[slot] = new InSelectorImpl(ioService, slot);
        outSelectors[slot] = new OutSelectorImpl(ioService, slot);
        inSelectors[slot].start();
        outSelectors[slot].start();
    }

    if (serverSocketChannel == null) {
        return;
    }
    // A leftover acceptor is shut down before a new one is started.
    if (socketAcceptorThread != null) {
        logger.warning("SocketAcceptor thread is already live! Shutting down old acceptor...");
        shutdownSocketAcceptor();
    }
    final String acceptorName = ioService.getThreadPrefix() + "Acceptor";
    final Runnable acceptorTask = new SocketAcceptor(serverSocketChannel, this);
    socketAcceptorThread = new Thread(ioService.getThreadGroup(), acceptorTask, acceptorName);
    socketAcceptorThread.start();
}
/**
 * Stops the manager and starts it again while holding this object's monitor.
 */
public synchronized void restart() {
    this.stop();
    this.start();
}
/**
 * Shuts the manager down: if live, stops it and clears the listeners; in every case the
 * server socket channel (when present) is closed afterwards.
 */
public synchronized void shutdown() {
    try {
        if (live) {
            stop();
            connectionListeners.clear();
        }
    } finally {
        closeServerSocketChannel();
    }
}

/** Closes the server socket channel if there is one; an IOException is logged at finest level. */
private void closeServerSocketChannel() {
    if (serverSocketChannel == null) {
        return;
    }
    try {
        if (logger.isFinestEnabled()) {
            log(Level.FINEST, "Closing server socket channel: " + serverSocketChannel);
        }
        serverSocketChannel.close();
    } catch (IOException ignore) {
        logger.finest(ignore);
    }
}
/**
 * Stops the manager: clears the live flag, shuts down the acceptor, closes sockets that
 * were accepted but not yet assigned, destroys all registered and active connections,
 * shuts down the IO selectors and finally clears the internal collections.
 */
private void stop() {
    live = false;
    log(Level.FINEST, "Stopping ConnectionManager");
    shutdownSocketAcceptor(); // interrupt acceptor thread after live=false

    for (SocketChannelWrapper pending : acceptedSockets) {
        IOUtil.closeResource(pending);
    }
    for (Connection registered : connectionsMap.values()) {
        destroyQuietly(registered);
    }
    for (TcpIpConnection active : activeConnections) {
        destroyQuietly(active);
    }

    shutdownIOSelectors();

    connectionsInProgress.clear();
    connectionsMap.clear();
    monitors.clear();
    activeConnections.clear();
}

/** Destroys a single connection; any Throwable is logged at finest level and not propagated. */
private void destroyQuietly(Connection conn) {
    try {
        destroyConnection(conn);
    } catch (final Throwable ignore) {
        logger.finest(ignore);
    }
}
/**
 * Shuts down every in/out selector that exists and clears its array slot.
 */
private synchronized void shutdownIOSelectors() {
    if (logger.isFinestEnabled()) {
        log(Level.FINEST, "Shutting down IO selectors... Total: " + selectorThreadCount);
    }
    for (int slot = 0; slot < selectorThreadCount; slot++) {
        final IOSelector in = inSelectors[slot];
        if (in != null) {
            in.shutdown();
        }
        inSelectors[slot] = null;

        final IOSelector out = outSelectors[slot];
        if (out != null) {
            out.shutdown();
        }
        outSelectors[slot] = null;
    }
}
/**
 * Interrupts the socket acceptor thread, if any, and waits up to ten seconds for it to end.
 * <p>
 * The thread reference is cleared before interrupting. If the calling thread is
 * interrupted while waiting, the wait is abandoned and the caller's interrupt status is
 * restored so that code further up the stack can still observe it.
 */
private void shutdownSocketAcceptor() {
    log(Level.FINEST, "Shutting down SocketAcceptor thread.");
    final Thread killingThread = socketAcceptorThread;
    if (killingThread == null) {
        return;
    }
    socketAcceptorThread = null;
    killingThread.interrupt();
    try {
        killingThread.join(1000 * 10);
    } catch (InterruptedException e) {
        logger.finest(e);
        // Catching InterruptedException clears the flag; put it back instead of swallowing it.
        Thread.currentThread().interrupt();
    }
}
/**
 * Counts the active connections that are both live and client connections.
 *
 * @return number of live client connections in the active set
 */
public int getCurrentClientConnections() {
    int liveClients = 0;
    for (TcpIpConnection candidate : activeConnections) {
        if (candidate.live() && candidate.isClient()) {
            liveClients++;
        }
    }
    return liveClients;
}
/**
 * Returns the live flag, which is set by {@code start()} and cleared by {@code stop()}.
 */
public boolean isLive() {
    return this.live;
}
/**
 * Returns an unmodifiable view of the endpoint-to-connection map.
 * Being a view, it reflects later changes made by this manager.
 */
public Map<Address, Connection> getReadonlyConnectionMap() {
    final Map<Address, Connection> readOnlyView = Collections.unmodifiableMap(connectionsMap);
    return readOnlyView;
}
/**
 * Writes the message to the logger at the given level and also passes it to the
 * system log service as a connection log entry.
 *
 * @param level   level used for the logger call
 * @param message text to log
 */
private void log(Level level, String message) {
    this.logger.log(level, message);
    this.ioService.getSystemLogService().logConnection(message);
}
/**
 * Tells whether no outbound ports were configured.
 *
 * @return true when the outbound port count is zero
 */
boolean useAnyOutboundPort() {
    final boolean noPortsConfigured = (0 == outboundPortCount);
    return noPortsConfigured;
}
/**
 * Returns the number of outbound ports that were configured at construction time.
 */
int getOutboundPortCount() {
    return this.outboundPortCount;
}
/**
 * Returns the next outbound port to use.
 * <p>
 * When no outbound ports are configured, 0 is returned. Otherwise the port at the head
 * of the list is moved to the tail and returned, so the ports are handed out in rotation.
 *
 * @return 0 when no ports are configured, otherwise the next configured port
 */
int acquireOutboundPort() {
    if (!useAnyOutboundPort()) {
        synchronized (outboundPorts) {
            final Integer head = outboundPorts.removeFirst();
            outboundPorts.addLast(head);
            return head;
        }
    }
    return 0;
}
/**
 * Renders the registered connections, one per line, followed by the live flag.
 */
@Override
public String toString() {
    final StringBuilder text = new StringBuilder("Connections {");
    for (Connection each : connectionsMap.values()) {
        text.append("\n").append(each);
    }
    text.append("\nlive=").append(live).append("\n}");
    return text.toString();
}
}