Package com.hazelcast.client.connection.nio

Source Code of com.hazelcast.client.connection.nio.ClientConnectionManagerImpl$SSLSocketChannelWrapperFactory

/*
* Copyright (c) 2008-2013, Hazelcast, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hazelcast.client.connection.nio;

import com.hazelcast.client.AuthenticationException;
import com.hazelcast.client.AuthenticationRequest;
import com.hazelcast.client.ClientPrincipal;
import com.hazelcast.client.ClientRequest;
import com.hazelcast.client.ClientResponse;
import com.hazelcast.client.HazelcastClient;
import com.hazelcast.client.LoadBalancer;
import com.hazelcast.client.config.ClientConfig;
import com.hazelcast.client.config.ClientNetworkConfig;
import com.hazelcast.client.config.ClientSecurityConfig;
import com.hazelcast.client.config.SocketOptions;
import com.hazelcast.client.connection.Authenticator;
import com.hazelcast.client.connection.ClientConnectionManager;
import com.hazelcast.client.connection.Router;
import com.hazelcast.client.spi.ClientClusterService;
import com.hazelcast.client.spi.EventHandler;
import com.hazelcast.client.spi.impl.ClientCallFuture;
import com.hazelcast.client.spi.impl.ClientExecutionServiceImpl;
import com.hazelcast.config.GroupConfig;
import com.hazelcast.config.SSLConfig;
import com.hazelcast.config.SocketInterceptorConfig;
import com.hazelcast.core.HazelcastException;
import com.hazelcast.core.ICompletableFuture;
import com.hazelcast.core.Member;
import com.hazelcast.logging.ILogger;
import com.hazelcast.logging.Logger;
import com.hazelcast.nio.Address;
import com.hazelcast.nio.ClassLoaderUtil;
import com.hazelcast.nio.ClientPacket;
import com.hazelcast.nio.DefaultSocketChannelWrapper;
import com.hazelcast.nio.IOSelector;
import com.hazelcast.nio.SocketChannelWrapper;
import com.hazelcast.nio.SocketInterceptor;
import com.hazelcast.nio.serialization.Data;
import com.hazelcast.nio.serialization.SerializationContext;
import com.hazelcast.nio.serialization.SerializationService;
import com.hazelcast.nio.ssl.BasicSSLContextFactory;
import com.hazelcast.nio.ssl.SSLContextFactory;
import com.hazelcast.nio.ssl.SSLSocketChannelWrapper;
import com.hazelcast.security.Credentials;
import com.hazelcast.security.UsernamePasswordCredentials;
import com.hazelcast.spi.exception.RetryableIOException;
import com.hazelcast.spi.impl.SerializableCollection;
import com.hazelcast.util.ExceptionUtil;

import java.io.IOException;
import java.net.Socket;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class ClientConnectionManagerImpl implements ClientConnectionManager {

    private static final ILogger logger = Logger.getLogger(ClientConnectionManagerImpl.class);

    private int RETRY_COUNT = 20;
    private final ConcurrentMap<Address, Object> connectionLockMap = new ConcurrentHashMap<Address, Object>();

    static final int KILO_BYTE = 1024;
    public static final int BUFFER_SIZE = 16 << 10; // 32k


    private final AtomicInteger connectionIdGen = new AtomicInteger();
    private final HazelcastClient client;
    private final Router router;
    private final SocketInterceptor socketInterceptor;
    private final SocketOptions socketOptions;
    private final IOSelector inSelector;
    private final IOSelector outSelector;
    private final boolean smartRouting;
    private volatile ClientConnection ownerConnection = null;
    private final Object ownerConnectionLock = new Object();

    private final Credentials credentials;
    private volatile ClientPrincipal principal;
    private final AtomicInteger callIdIncrementer = new AtomicInteger();
    private final SocketChannelWrapperFactory socketChannelWrapperFactory;
    private final ClientExecutionServiceImpl executionService;
    private final ConcurrentMap<Address, ClientConnection> connections
            = new ConcurrentHashMap<Address, ClientConnection>();

    private volatile boolean live = false;

    public ClientConnectionManagerImpl(HazelcastClient client, LoadBalancer loadBalancer) {
        this.client = client;
        final ClientConfig config = client.getClientConfig();
        final ClientNetworkConfig networkConfig = config.getNetworkConfig();
        final GroupConfig groupConfig = config.getGroupConfig();
        final ClientSecurityConfig securityConfig = config.getSecurityConfig();
        Credentials c = securityConfig.getCredentials();
        if (c == null) {
            final String credentialsClassname = securityConfig.getCredentialsClassname();
            //todo: Should be moved to a reflection utility.
            if (credentialsClassname != null) {
                try {
                    c = ClassLoaderUtil.newInstance(config.getClassLoader(), credentialsClassname);
                } catch (Exception e) {
                    throw ExceptionUtil.rethrow(e);
                }
            }
        }
        if (c == null) {
            c = new UsernamePasswordCredentials(groupConfig.getName(), groupConfig.getPassword());
        }

        this.smartRouting = networkConfig.isSmartRouting();
        this.executionService = (ClientExecutionServiceImpl) client.getClientExecutionService();
        this.credentials = c;
        router = new Router(loadBalancer);
        inSelector = new ClientInSelectorImpl(client.getThreadGroup());
        outSelector = new ClientOutSelectorImpl(client.getThreadGroup());
        //init socketInterceptor
        SocketInterceptorConfig sic = networkConfig.getSocketInterceptorConfig();
        SocketInterceptor implementation = null;
        if (sic != null && sic.isEnabled()) {
            implementation = (SocketInterceptor) sic.getImplementation();
            if (implementation == null && sic.getClassName() != null) {
                try {
                    implementation = (SocketInterceptor) Class.forName(sic.getClassName()).newInstance();
                } catch (Throwable e) {
                    logger.severe("SocketInterceptor class cannot be instantiated!" + sic.getClassName(), e);
                }
            }
        }

        socketInterceptor = implementation;
        if (socketInterceptor != null) {
            logger.info("SocketInterceptor is enabled");
            socketInterceptor.init(sic.getProperties());
        }

        socketOptions = networkConfig.getSocketOptions();

        SSLConfig sslConfig = networkConfig.getSSLConfig(); //ioService.getSSLConfig(); TODO
        if (sslConfig != null && sslConfig.isEnabled()) {
            socketChannelWrapperFactory = new SSLSocketChannelWrapperFactory(sslConfig);
            logger.info("SSL is enabled");
        } else {
            socketChannelWrapperFactory = new DefaultSocketChannelWrapperFactory();
        }

    }

    public boolean isLive() {
        return live;
    }

    public SerializationContext getSerializationContext() {
        return client.getSerializationService().getSerializationContext();
    }

    public SerializationService getSerializationService() {
        return client.getSerializationService();
    }

    @Override
    public synchronized void start() {
        if (live) {
            return;
        }
        live = true;
        inSelector.start();
        outSelector.start();
    }

    @Override
    public synchronized void shutdown() {
        if (!live) {
            return;
        }
        live = false;
        for (ClientConnection connection : connections.values()) {
            connection.close();
        }
        inSelector.shutdown();
        outSelector.shutdown();
        connectionLockMap.clear();
    }


    public void markOwnerAddressAsClosed() {
        synchronized (ownerConnectionLock) {
            ownerConnection = null;
        }
    }

    private Address waitForOwnerConnection() throws RetryableIOException {
        if (ownerConnection != null) {
            return ownerConnection.getRemoteEndpoint();
        }

        synchronized (ownerConnectionLock) {
            ClientNetworkConfig networkConfig = client.getClientConfig().getNetworkConfig();
            int connectionAttemptLimit = networkConfig.getConnectionAttemptLimit();
            int connectionAttemptPeriod = networkConfig.getConnectionAttemptPeriod();
            int waitTime = connectionAttemptLimit * connectionAttemptPeriod * 2;

            while (ownerConnection == null) {
                try {
                    ownerConnectionLock.wait(waitTime);
                } catch (InterruptedException e) {
                    logger.warning("Wait for owner connection is timed out");
                    throw new RetryableIOException(e);
                }
            }
            return ownerConnection.getRemoteEndpoint();
        }
    }

    @Override
    public ClientConnection ownerConnection(Address address) throws Exception {
        final ManagerAuthenticator authenticator = new ManagerAuthenticator();
        final ConnectionProcessor connectionProcessor = new ConnectionProcessor(address, authenticator, true);
        ICompletableFuture<ClientConnection> future = executionService.submitInternal(connectionProcessor);
        try {
            ownerConnection = future.get(5, TimeUnit.SECONDS);
            synchronized (ownerConnectionLock) {
                ownerConnectionLock.notifyAll();
            }
            return ownerConnection;
        } catch (Exception e) {
            future.cancel(true);
            throw new RetryableIOException(e);
        }
    }

    @Override
    public ClientConnection tryToConnect(Address target) throws Exception {
        Authenticator authenticator = new ClusterAuthenticator();
        int count = 0;
        IOException lastError = null;
        while (count < RETRY_COUNT) {
            try {
                if (target == null || !isMember(target)) {
                    Address address = getAddressFromLoadBalancer();
                    return getOrConnect(address, authenticator);
                } else {
                    return getOrConnect(target, authenticator);
                }
            } catch (IOException e) {
                lastError = e;
            }
            target = null;
            count++;
        }
        throw lastError;
    }

    private Address getAddressFromLoadBalancer() {
        Address address = router.next();
        if (address == null) {
            Set<Member> members = client.getCluster().getMembers();
            String msg;
            if (members.isEmpty()) {
                msg = "No address was return by the LoadBalancer since there are no members in the cluster";
            } else {
                msg = "No address was return by the LoadBalancer. " +
                        "But the cluster contains the following members:" + members;
            }
            throw new IllegalStateException(msg);
        }

        return address;
    }

    public ClientPrincipal getPrincipal() {
        return principal;
    }

    private boolean isMember(Address target) {
        final ClientClusterService clientClusterService = client.getClientClusterService();
        return clientClusterService.getMember(target) != null;
    }

    private ClientConnection getOrConnect(Address address, Authenticator authenticator) throws Exception {
        if (address == null) {
            throw new NullPointerException("Address is required!");
        }
        if (!smartRouting) {
            address = waitForOwnerConnection();
        }
        ClientConnection clientConnection = connections.get(address);
        if (clientConnection == null) {
            final Object lock = getLock(address);
            synchronized (lock) {
                clientConnection = connections.get(address);
                if (clientConnection == null) {
                    final ConnectionProcessor connectionProcessor = new ConnectionProcessor(address, authenticator, false);
                    final ICompletableFuture<ClientConnection> future = executionService.submitInternal(connectionProcessor);
                    try {
                        clientConnection = future.get(5, TimeUnit.SECONDS);
                    } catch (Exception e) {
                        future.cancel(true);
                        throw new RetryableIOException(e);
                    }
                    ClientConnection current = connections.putIfAbsent(clientConnection.getRemoteEndpoint(), clientConnection);
                    if (current != null) {
                        clientConnection.innerClose();
                        clientConnection = current;
                    }
                }
            }
        }
        return clientConnection;
    }

    private class ConnectionProcessor implements Callable<ClientConnection> {

        final Address address;
        final Authenticator authenticator;
        final boolean isBlock;

        private ConnectionProcessor(final Address address, final Authenticator authenticator, final boolean isBlock) {
            this.address = address;
            this.authenticator = authenticator;
            this.isBlock = isBlock;
        }

        @Override
        public ClientConnection call() throws Exception {
            if (!live) {
                throw new HazelcastException("ConnectionManager is not active!!!");
            }
            SocketChannel socketChannel = null;
            try {
                socketChannel = SocketChannel.open();
                Socket socket = socketChannel.socket();
                socket.setKeepAlive(socketOptions.isKeepAlive());
                socket.setTcpNoDelay(socketOptions.isTcpNoDelay());
                socket.setReuseAddress(socketOptions.isReuseAddress());
                if (socketOptions.getLingerSeconds() > 0) {
                    socket.setSoLinger(true, socketOptions.getLingerSeconds());
                }
                int bufferSize = socketOptions.getBufferSize() * KILO_BYTE;
                if (bufferSize < 0) {
                    bufferSize = BUFFER_SIZE;
                }
                socket.setSendBufferSize(bufferSize);
                socket.setReceiveBufferSize(bufferSize);
                socketChannel.socket().connect(address.getInetSocketAddress(), 5000);
                SocketChannelWrapper socketChannelWrapper = socketChannelWrapperFactory.wrapSocketChannel(socketChannel, true);
                final ClientConnection clientConnection = new ClientConnection(ClientConnectionManagerImpl.this, inSelector,
                        outSelector, connectionIdGen.incrementAndGet(), socketChannelWrapper, executionService);
                socketChannel.configureBlocking(true);
                if (socketInterceptor != null) {
                    socketInterceptor.onConnect(socket);
                }
                authenticator.auth(clientConnection);
                socketChannel.configureBlocking(isBlock);
                socket.setSoTimeout(0);
                if (!isBlock) {
                    clientConnection.getReadHandler().register();
                }
                return clientConnection;
            } catch (Exception e) {
                if (socketChannel != null) {
                    socketChannel.close();
                }
                throw ExceptionUtil.rethrow(e);
            }
        }
    }

    public void destroyConnection(ClientConnection clientConnection) {
        Address endpoint = clientConnection.getRemoteEndpoint();
        if (endpoint != null) {
            connections.remove(clientConnection.getRemoteEndpoint());
            closeIfOwnerConnection(endpoint);
        }
    }

    private void closeIfOwnerConnection(Address endpoint) {
        final ClientConnection currentOwnerConnection = ownerConnection;
        if (currentOwnerConnection == null || !currentOwnerConnection.live()) {
            return;
        }
        if (endpoint.equals(currentOwnerConnection.getRemoteEndpoint())) {
            try {
                currentOwnerConnection.close();
            } catch (Exception ignored) {
            }
        }
    }

    @Override
    public boolean removeEventHandler(Integer callId) {
        if (callId != null) {
            for (ClientConnection clientConnection : connections.values()) {
                if (clientConnection.deRegisterEventHandler(callId) != null) {
                    return true;
                }
            }
        }
        return false;
    }

    public void handlePacket(ClientPacket packet) {
        final ClientConnection conn = (ClientConnection) packet.getConn();
        conn.incrementPacketCount();
        executionService.executeInternal(new ClientPacketProcessor(packet));
    }

    public int newCallId() {
        return callIdIncrementer.incrementAndGet();
    }

    private class ClientEventProcessor implements Runnable {
        final int callId;
        final ClientConnection conn;
        final Data response;

        private ClientEventProcessor(int callId, ClientConnection conn, Data response) {
            this.callId = callId;
            this.conn = conn;
            this.response = response;
        }

        @Override
        public void run() {
            handleEvent(response, callId, conn);
        }

        private void handleEvent(Data event, int callId, ClientConnection conn) {
            final EventHandler eventHandler = conn.getEventHandler(callId);
            final Object eventObject = getSerializationService().toObject(event);
            if (eventHandler == null) {
                logger.warning("No eventHandler for callId: " + callId + ", event: " + eventObject + ", conn: " + conn);
                return;
            }
            eventHandler.handle(eventObject);
        }
    }

    private class ClientPacketProcessor implements Runnable {

        final ClientPacket packet;

        ClientPacketProcessor(ClientPacket packet) {
            this.packet = packet;
        }

        @Override
        public void run() {
            final ClientConnection conn = (ClientConnection) packet.getConn();
            final ClientResponse clientResponse = getSerializationService().toObject(packet.getData());
            final int callId = clientResponse.getCallId();
            final Data response = clientResponse.getResponse();
            if (clientResponse.isEvent()) {
                executionService.execute(new ClientEventProcessor(callId, conn, response));
            } else {
                handlePacket(response, clientResponse.isError(), callId, conn);
            }
            conn.decrementPacketCount();
        }

        private void handlePacket(Object response, boolean isError, int callId, ClientConnection conn) {
            final ClientCallFuture future = conn.deRegisterCallId(callId);
            if (future == null) {
                logger.warning("No call for callId: " + callId + ", response: " + response);
                return;
            }
            if (isError) {
                response = getSerializationService().toObject(response);
            }
            future.notify(response);
        }
    }


    public class ManagerAuthenticator implements Authenticator {

        @Override
        public void auth(ClientConnection connection) throws AuthenticationException, IOException {
            final Object response = authenticate(connection, credentials, principal, true, true);
            principal = (ClientPrincipal) response;
        }
    }

    private class ClusterAuthenticator implements Authenticator {
        @Override
        public void auth(ClientConnection connection) throws AuthenticationException, IOException {
            authenticate(connection, credentials, principal, false, false);
        }
    }

    private Object authenticate(ClientConnection connection, Credentials credentials, ClientPrincipal principal,
                                boolean reAuth, boolean firstConnection) throws IOException {
        final SerializationService ss = getSerializationService();
        AuthenticationRequest auth = new AuthenticationRequest(credentials, principal);
        connection.init();
        auth.setReAuth(reAuth);
        auth.setFirstConnection(firstConnection);
        SerializableCollection collectionWrapper; //contains remoteAddress and principal
        try {
            collectionWrapper = (SerializableCollection) sendAndReceive(auth, connection);
        } catch (Exception e) {
            throw new RetryableIOException(e);
        }
        final Iterator<Data> iter = collectionWrapper.iterator();
        if (iter.hasNext()) {
            final Data addressData = iter.next();
            final Address address = ss.toObject(addressData);
            connection.setRemoteEndpoint(address);
            if (iter.hasNext()) {
                final Data principalData = iter.next();
                return ss.toObject(principalData);
            }
        }
        throw new AuthenticationException();
    }

    public Object sendAndReceive(ClientRequest request, ClientConnection connection) throws Exception {
        final SerializationService ss = client.getSerializationService();
        connection.write(ss.toData(request));
        final Data data = connection.read();
        ClientResponse clientResponse = ss.toObject(data);
        Object response = ss.toObject(clientResponse.getResponse());
        if (response instanceof Throwable) {
            Throwable t = (Throwable) response;
            ExceptionUtil.fixRemoteStackTrace(t, Thread.currentThread().getStackTrace());
            throw new Exception(t);
        }
        return response;
    }

    interface SocketChannelWrapperFactory {
        SocketChannelWrapper wrapSocketChannel(SocketChannel socketChannel, boolean client) throws Exception;
    }

    static class DefaultSocketChannelWrapperFactory implements SocketChannelWrapperFactory {
        @Override
        public SocketChannelWrapper wrapSocketChannel(SocketChannel socketChannel, boolean client) throws Exception {
            return new DefaultSocketChannelWrapper(socketChannel);
        }
    }

    static class SSLSocketChannelWrapperFactory implements SocketChannelWrapperFactory {
        final SSLContextFactory sslContextFactory;

        SSLSocketChannelWrapperFactory(SSLConfig sslConfig) {
//            if (CipherHelper.isSymmetricEncryptionEnabled(ioService)) {
//                throw new RuntimeException("SSL and SymmetricEncryption cannot be both enabled!");
//            }
            SSLContextFactory sslContextFactoryObject = (SSLContextFactory) sslConfig.getFactoryImplementation();
            try {
                String factoryClassName = sslConfig.getFactoryClassName();
                if (sslContextFactoryObject == null && factoryClassName != null) {
                    sslContextFactoryObject = (SSLContextFactory) Class.forName(factoryClassName).newInstance();
                }
                if (sslContextFactoryObject == null) {
                    sslContextFactoryObject = new BasicSSLContextFactory();
                }
                sslContextFactoryObject.init(sslConfig.getProperties());
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            sslContextFactory = sslContextFactoryObject;
        }

        @Override
        public SocketChannelWrapper wrapSocketChannel(SocketChannel socketChannel, boolean client) throws Exception {
            return new SSLSocketChannelWrapper(sslContextFactory.getSSLContext(), socketChannel, client);
        }
    }

    private Object getLock(Address address) {
        Object lock = connectionLockMap.get(address);
        if (lock == null) {
            lock = new Object();
            Object current = connectionLockMap.putIfAbsent(address, lock);
            if (current != null) {
                lock = current;
            }
        }
        return lock;
    }

}
TOP

Related Classes of com.hazelcast.client.connection.nio.ClientConnectionManagerImpl$SSLSocketChannelWrapperFactory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.