Package org.apache.openejb.server.discovery

Source Code of org.apache.openejb.server.discovery.MultipointServer$Session

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
*  Unless required by applicable law or agreed to in writing, software
*  distributed under the License is distributed on an "AS IS" BASIS,
*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*  See the License for the specific language governing permissions and
*  limitations under the License.
*/
package org.apache.openejb.server.discovery;

import org.apache.openejb.util.LogCategory;
import org.apache.openejb.util.Logger;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* @version $Rev$ $Date$
*/
public class MultipointServer {
    // Logger for this class, under a "discovery" child of the OPENEJB_SERVER category.
    private static final Logger log = Logger.getInstance(LogCategory.OPENEJB_SERVER.createChild("discovery"), MultipointServer.class);

    // Host name and port the listening socket is bound to.
    private final String host;
    private final int port;
    // Selector shared by the listening channel and every peer session.
    private final Selector selector;
    // This server's own address, built as "conn://" + host + ":" + port.
    private final URI me;

    // Supplies the heart rate and registered services, and receives messages read while heartbeating.
    private final Tracker tracker;

    // URIs waiting for an outbound connection attempt; this list is also the
    // monitor that guards every access to 'connections'.
    private final LinkedList<URI> connect = new LinkedList<URI>();
    // Sessions indexed by the URI of the peer they talk to.
    private final Map<URI, Session> connections = new HashMap<URI, Session>();

    public MultipointServer(int port, Tracker tracker) throws IOException {
        this("localhost", port, tracker);
    }

    public MultipointServer(String host, int port, Tracker tracker) throws IOException {
        if (tracker == null) throw new NullPointerException("tracker cannot be null");
        this.host = host;
        this.port = port;
        this.tracker = tracker;
        me = URI.create("conn://" + host + ":" + port);

        ServerSocketChannel serverChannel = ServerSocketChannel.open();

        ServerSocket serverSocket = serverChannel.socket();
        InetSocketAddress address = new InetSocketAddress(host,port);
        serverSocket.bind(address);
        serverChannel.configureBlocking(false);

        selector = Selector.open();

        serverChannel.register(selector, SelectionKey.OP_ACCEPT);

        println("Listening");
    }

    public MultipointServer start() {
        if (running.compareAndSet(false, true)) {
            Thread thread = new Thread(new Runnable() {
                public void run() {
                    _run();
                }
            });
            thread.setName("Server." + port);
            thread.start();
        }
        return this;
    }

    public void stop() {
        running.set(false);
    }

    public class Session {

        private static final int EOF = 3;

        private final SocketChannel channel;
        private final ByteBuffer read = ByteBuffer.allocate(1024);
        private final SelectionKey key;
        private final List<URI> listed = new ArrayList<URI>();

        private ByteBuffer write;
        private State state = State.OPEN;
        private URI uri;
        public boolean hangup;
        private final boolean client;

        public Session(SocketChannel channel, InetSocketAddress address, URI uri) throws ClosedChannelException {
            this.channel = channel;
            this.client = uri != null;
            this.uri = uri != null ? uri : URI.create("conn://" + address.getHostName() + ":" + address.getPort());
            this.key = channel.register(selector, 0, this);
        }

        public Session ops(int ops) {
            key.interestOps(ops);
            return this;
        }

        public void state(int ops, State state) {
            this.state = state;
            if (ops > 0) key.interestOps(ops);
        }

        public void setURI(URI uri) {
            this.uri = uri;
        }

        private void trace(String str) {
//            println(message(str));

            if (log.isDebugEnabled()) {
                log.debug(message(str));
            }
        }

        private String message(String str) {
            StringBuilder sb = new StringBuilder();
            sb.append(port);
            sb.append(" ");
            if (key.isValid()) {
                if ((key.interestOps() & SelectionKey.OP_READ) == SelectionKey.OP_READ) sb.append("<");
                if ((key.interestOps() & SelectionKey.OP_WRITE) == SelectionKey.OP_WRITE) sb.append(">");
                if ((key.interestOps() == 0)) sb.append("-");
            } else {
                sb.append(":");
            }
            sb.append(" ");
            sb.append(uri.getPort());
            sb.append(" ");
            sb.append(this.state);
            sb.append(" ");
            sb.append(str);
            String x = sb.toString();
            return x;
        }

        public void write(URI uri) throws IOException {
            write(Arrays.asList(uri));
        }

        public void write(Collection<?> uris) throws IOException {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();

            for (Object uri : uris) {
                String s = uri.toString();
                byte[] b = s.getBytes("UTF-8");
                baos.write(b);
                baos.write(EOF);
            }

            this.write = ByteBuffer.wrap(baos.toByteArray());
        }

        public boolean drain() throws IOException {
            this.channel.write(write);
            return write.remaining() == 0;
        }

        public String read() throws IOException {

            if (channel.read(read) == -1) throw new EOFException();

            byte[] buf = read.array();

            int end = endOfText(buf, 0, read.position());

            if (end < 0) return null;

            // Copy the string without the terminator char
            String text = new String(buf, 0, end, "UTF-8");

            int newPos = read.position() - end;
            System.arraycopy(buf, end + 1, buf, 0, newPos - 1);
            read.position(newPos - 1);

            return text;
        }

        private int endOfText(byte[] data, int offset, int pos) {
            for (int i = offset; i < pos; i++) if (data[i] == EOF) return i;
            return -1;
        }

        @Override
        public String toString() {
            return "Session{" +
                    "uri=" + uri +
                    ", state=" + state +
                    ", owner=" + port +
                    ", s=" + (client ? channel.socket().getPort() : channel.socket().getLocalPort()) +
                    ", c=" + (!client ? channel.socket().getPort() : channel.socket().getLocalPort()) +
                    ", " + (client ? "client" : "server") +
                    '}';
        }

        private long last = 0;

        public void tick() throws IOException {
            if (state != State.HEARTBEAT) return;

            long now = System.currentTimeMillis();
            long delay = now - last;

            if (delay > tracker.getHeartRate()) {
                last = now;
                heartbeat();
            }

        }

        private void heartbeat() throws IOException {
            write(tracker.getRegisteredServices());
            state(SelectionKey.OP_READ | SelectionKey.OP_WRITE, State.HEARTBEAT);

            tracker.checkServices();
        }
    }

    // Phase of a Session.  A session starts OPEN, is GREETING while the
    // connecting side's address is exchanged, LISTING while the two sides
    // trade lists of known URIs, HEARTBEAT once listing is finished, and
    // CLOSED after it has been closed or hung up.
    private static enum State {
        OPEN, GREETING, LISTING, HEARTBEAT, CLOSED
    }

    // True while the selector loop should keep going; set by start(), cleared by stop().
    private final AtomicBoolean running = new AtomicBoolean();

    private void _run() {
        while (running.get()) {
            try {
                selector.select(1000);
            } catch (IOException ex) {
                ex.printStackTrace();
                break;
            }

            Set keys = selector.selectedKeys();

            Iterator iterator = keys.iterator();
            while (iterator.hasNext()) {
                SelectionKey key = (SelectionKey) iterator.next();
                iterator.remove();

                try {
                    if (key.isAcceptable()) {

                        // we are a server

                        // when you are a server, we must first listen for the
                        // address of the client before sending data.

                        // once they send us their address, we will send our
                        // full list of known addresses, followed by our own
                        // address to signal that we are done.

                        // Afterward we will only pulls our heartbeat

                        ServerSocketChannel server = (ServerSocketChannel) key.channel();
                        SocketChannel client = server.accept();
                        InetSocketAddress address = (InetSocketAddress) client.socket().getRemoteSocketAddress();

                        client.configureBlocking(false);

                        Session session = new Session(client, address, null);
                        session.trace("accept");
                        session.state(java.nio.channels.SelectionKey.OP_READ, State.GREETING);
                    }

                    if (key.isConnectable()) {

                        // we are a client

                        Session session = (Session) key.attachment();
                        session.channel.finishConnect();
                        session.trace("connected");

                        // when you are a client, first say high to everyone
                        // before accepting data

                        // once a server reads our address, it will send it's
                        // full list of known addresses, followed by it's own
                        // address to signal that it is done.

                        // we will initiate connections to everyone in the list
                        // who we have not yet seen.

                        // Afterward the server will only pulls its heartbeat

                        session.write(me);

                        session.state(java.nio.channels.SelectionKey.OP_WRITE, State.GREETING);
                    }

                    if (key.isReadable()) {


                        Session session = (Session) key.attachment();

                        switch (session.state) {
                            case GREETING: { // read

                                String message = session.read();

                                if (message == null) break; // need to read more

                                session.setURI(URI.create(message));

                                connected(session);

                                session.trace("welcome");

                                ArrayList<URI> list = connections();

                                // When they read themselves on the list
                                // they'll know it's time to list their URIs

                                list.remove(me); // yank
                                list.remove(session.uri); // yank
                                list.add(session.uri); // add to the end

                                session.write(list);

                                session.state(java.nio.channels.SelectionKey.OP_WRITE, State.LISTING);

                                session.trace("STARTING");


                            }
                            break;

                            case LISTING: { // read

                                String message = null;

                                while ((message = session.read()) != null) {

                                    URI uri = URI.create(message);

                                    session.listed.add(uri);

                                    session.trace(message);

                                    // they listed me, means they want my list
                                    if (uri.equals(me)) {
                                        ArrayList<URI> list = connections();

                                        for (URI reported : session.listed) {
                                            list.remove(reported);
                                        }

                                        // When they read us on the list
                                        // they'll know it's time to switch to heartbeat

                                        list.remove(session.uri);
                                        list.remove(me); // yank if in the middle
                                        list.add(me); // add to the end

                                        session.write(list);

                                        session.state(java.nio.channels.SelectionKey.OP_WRITE, State.LISTING);

                                    } else if (uri.equals(session.uri)) {

                                        if (session.hangup) {
                                            session.state(0, State.CLOSED);
                                            session.trace("hangup");
                                            hangup(key);
                                        } else {
                                            session.state(java.nio.channels.SelectionKey.OP_READ, State.HEARTBEAT);
                                        }

                                    } else {
                                        try {
                                            connect(uri);
                                        } catch (Exception e) {
                                            println("connect failed " + uri + " - " + e.getMessage());
                                            e.printStackTrace();
                                        }
                                    }
                                }

                            }
                            break;

                            case HEARTBEAT: { // read

                                String message = null;
                                while ((message = session.read()) != null) {
                                    tracker.processData(message);
                                }
                            }
                            break;
                        }

                    }

                    if (key.isWritable()) {

                        Session session = (Session) key.attachment();

                        switch (session.state) {
                            case GREETING: { // write

                                if (session.drain()) {
                                    session.state(java.nio.channels.SelectionKey.OP_READ, State.LISTING);
                                }

                            }
                            break;

                            case LISTING: { // write

                                if (session.drain()) {

                                    // we haven't ready any URIs yet
                                    if (session.listed.size() == 0) {

                                        session.state(java.nio.channels.SelectionKey.OP_READ, State.LISTING);

                                    } else {

                                        session.trace("DONE");

                                        session.state(java.nio.channels.SelectionKey.OP_READ, State.HEARTBEAT);
                                    }
                                }
                            }
                            break;

                            case HEARTBEAT: { // write

                                if (session.drain()) {

                                    session.last = System.currentTimeMillis();

                                    session.trace("ping");

                                    session.state(java.nio.channels.SelectionKey.OP_READ, State.HEARTBEAT);

                                }

                            }
                            break;
                        }
                    }

                } catch (ClosedChannelException ex) {
                    synchronized (connect) {
                        Session session = (Session) key.attachment();
                        if (session.state != State.CLOSED) {
                            close(key);       
                        }
                    }
                } catch (IOException ex) {
                    Session session = (Session) key.attachment();
                    session.trace(ex.getClass().getSimpleName() + ": " + ex.getMessage());
                    close(key);
                }

            }

            for (SelectionKey key : selector.keys()) {
                Session session = (Session) key.attachment();

                try {
                    if (session != null) session.tick();
                } catch (IOException ex) {
                    close(key);
                }
            }

            synchronized (connect) {
                while (connect.size() > 0) {

                    URI uri = connect.removeFirst();

                    if (connections.containsKey(uri)) continue;

                    int port = uri.getPort();
                    String host = uri.getHost();

                    try {
                        println("open " + uri);

                        SocketChannel socketChannel = SocketChannel.open();
                        socketChannel.configureBlocking(false);

                        InetSocketAddress address = new InetSocketAddress(host, port);

                        socketChannel.connect(address);

                        Session session = new Session(socketChannel, address, uri);
                        session.ops(java.nio.channels.SelectionKey.OP_CONNECT);
                        session.trace("client");
                        connections.put(session.uri, session);
                       
                        // seen - needs to get maintained as "connected"
                        // TODO remove from seen
                    } catch (IOException e) {
                        log.warning("Error connecting to " + host + ":" + port, e);
                    }
                }
            }
        }
    }

    private ArrayList<URI> connections() {
        synchronized (connect) {
            ArrayList<URI> list = new ArrayList<URI>(connections.keySet());
            list.addAll(connect);
            return list;
        }
    }

    private void close(SelectionKey key) {
        Session session = (Session) key.attachment();

        session.state(0, State.CLOSED);

        if (session.hangup) {
            // This was a duplicate connection and was closed
            // do not remove this URI from the 'connections'
            // map as this particular session is not in that
            // map -- only the good session that will not be
            // closed is in there.
            session.trace("hungup");
        } else {
            session.trace("closed");
            synchronized (connect) {
                connections.remove(session.uri);
            }
        }

        hangup(key);
    }

    private void hangup(SelectionKey key) {
        key.cancel();
        try {
            key.channel().close();
        } catch (IOException cex) {
        }

    }


    public void connect(MultipointServer s) throws Exception {
        connect(s.port);
    }

    public void connect(int port) throws Exception {
        connect(URI.create("conn://localhost:" + port));
    }

    public void connect(URI uri) throws Exception {
        if (me.equals(uri)) return;

        synchronized (connect) {
            if (!connections.containsKey(uri) && !connect.contains(uri)) {
                connect.addLast(uri);
            }
        }
    }

    private void connected(Session session) {

        synchronized (connect) {
            Session duplicate = connections.get(session.uri);
//            Session duplicate = null;

            if (duplicate != null) {
                session.trace("duplicate");

                // At this point we know we have two sockets open
                // to the client, one created by them and one created
                // by us.  We will both have detected this situation
                // and know it needs fixing.  Only one of us can hangup

                Session[] sessions = {session, duplicate};
                Arrays.sort(sessions, new Comparator<Session>() {
                    // Goal: Keep the connection with the lowest port number
                    ///
                    // Low vs high is not very significant.  The critical
                    // part is that they both choose the same connection.
                    //
                    // Port numbers are seen on both sides.  There are two
                    // ports (one client and one server) for each connection.
                    //
                    // Both sides will agree to kill the connection with the
                    // lowest server port.  If those are the same, then both
                    // sides will agree to kill the connection with the lowest
                    // client port.  If those are the same, we still close a
                    // connection and hope for the best.  If both connections
                    // are killed we will try again next time another node
                    // lists the server and we notice we are not connected.
                    //
                    public int compare(Session a, Session b) {
                        int serverRank = server(a) - server(b);
                        if (serverRank != 0) return serverRank;
                        return client(a) - client(b);
                    }

                    private int server(Session a) {
                        Socket socket = a.channel.socket();
                        return a.client? socket.getPort(): socket.getLocalPort();
                    }

                    private int client(Session a) {
                        Socket socket = a.channel.socket();
                        return !a.client? socket.getPort(): socket.getLocalPort();
                    }
                });

                session = sessions[0];
                duplicate = sessions[1];

                session.trace(session + "@" + session.hashCode() + " KEEP");
                duplicate.trace(duplicate + "@" + duplicate.hashCode() + " KILL");

                duplicate.hangup = true;
            }

            connections.put(session.uri, session);
        }
    }

    private void println(String s) {
//        if (s.matches(".*(Listening|DONE|KEEP|KILL)")) {
//            System.out.format("%1$tH:%1$tM:%1$tS.%1$tL - %2$s\n", System.currentTimeMillis(), s);
//        }
    }
}
TOP

Related Classes of org.apache.openejb.server.discovery.MultipointServer$Session

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.