// Package org.jwebsocket.netty.engines
//
// Source Code of org.jwebsocket.netty.engines.NettyEngineHandler$SecureWebSocketConnectionListener

//  ---------------------------------------------------------------------------
//  jWebSocket - Copyright (c) 2010 jwebsocket.org
//  ---------------------------------------------------------------------------
//  This program is free software; you can redistribute it and/or modify it
//  under the terms of the GNU Lesser General Public License as published by the
//  Free Software Foundation; either version 3 of the License, or (at your
//  option) any later version.
//  This program is distributed in the hope that it will be useful, but WITHOUT
//  ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
//  FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for
//  more details.
//  You should have received a copy of the GNU Lesser General Public License along
//  with this program; if not, see <http://www.gnu.org/licenses/lgpl.html>.
//  ---------------------------------------------------------------------------
package org.jwebsocket.netty.engines;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import javolution.util.FastMap;
import org.apache.log4j.Logger;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.ChannelEvent;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.ChildChannelStateEvent;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.WriteCompletionEvent;
import org.jboss.netty.channel.group.ChannelGroup;
import org.jboss.netty.channel.group.DefaultChannelGroup;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrame;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameDecoder;
import org.jboss.netty.handler.codec.http.websocket.WebSocketFrameEncoder;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.util.CharsetUtil;
import org.jwebsocket.api.EngineConfiguration;
import org.jwebsocket.api.WebSocketConnector;
import org.jwebsocket.config.JWebSocketCommonConstants;
import org.jwebsocket.kit.CloseReason;
import org.jwebsocket.kit.RawPacket;
import org.jwebsocket.kit.RequestHeader;
import org.jwebsocket.kit.WebSocketRuntimeException;
import org.jwebsocket.logging.Logging;
import org.jwebsocket.netty.connectors.NettyConnector;
import org.jwebsocket.netty.http.HttpHeaders;

/**
* Handler class for the <tt>NettyEngine</tt> that receives the events based on
* event types and notifies the client connectors. This handler also handles the
* initial handshaking for a WebSocket connection with an appropriate handshake
* response. This handler is created for each new connection channel.
* <p>
* Once the handshaking is successful after sending the handshake {@code
* HttpResponse} it replaces the {@code HttpRequestDecoder} and {@code
* HttpResponseEncoder} from the channel pipeline with {@code
* WebSocketFrameDecoder} as WebSocket frame data decoder and {@code
* WebSocketFrameEncoder} as WebSocket frame data encoder. Also it starts the
* <tt>NettyConnector</tt>.
* </p>
*
* @author <a href="http://www.purans.net/">Puran Singh</a>
* @version $Id: NettyEngineHandler.java 613 2010-07-01 07:13:29Z mailtopuran@gmail.com $
*/
public class NettyEngineHandler extends SimpleChannelUpstreamHandler {

    /** Logger shared by all instances of this handler. */
    private static Logger log = Logging.getLogger(NettyEngineHandler.class);

    /** Engine that is notified about connector lifecycle events and incoming packets. */
    private NettyEngine engine = null;

    /** Connector bound to this channel; stays {@code null} until a WebSocket handshake has been served. */
    private WebSocketConnector connector = null;

    /** Handler context recorded by the most recent event callback; exposed via {@link #getChannelHandlerContext()}. */
    private ChannelHandlerContext context = null;

    /** Channels that were served a WebSocket handshake; static, hence shared by all handler instances. */
    private static final ChannelGroup channels = new DefaultChannelGroup();

    /** Name of the HTTP header set on generated error responses. */
    private static final String CONTENT_LENGTH = "Content-Length";

    // Keys under which the request details are stored in the connector's RequestHeader.
    private static final String ARGS = "args";
    private static final String ORIGIN = "origin";
    private static final String LOCATION = "location";
    private static final String PATH = "path";
    private static final String SEARCH_STRING = "searchString";
    private static final String HOST = "host";

    /**
     * Creates a handler that reports connector events and incoming packets
     * to the given engine.
     *
     * @param aEngine the engine this handler belongs to
     */
    public NettyEngineHandler(NettyEngine aEngine) {
        this.engine = aEngine;
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void channelBound(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.channelBound(ctx, e);
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.channelClosed(ctx, e);
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.channelConnected(ctx, e);
    }

    /**
     * {@inheritDoc }
     */
    @Override
    public void channelDisconnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        if (log.isDebugEnabled()) {
            log.debug("Channel is disconnected");
        }
        // remove the channel
        channels.remove(e.getChannel());

        this.context = ctx;
        super.channelDisconnected(ctx, e);
        engine.connectorStopped(connector, CloseReason.CLIENT);
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void channelInterestChanged(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.channelInterestChanged(ctx, e);
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.channelOpen(ctx, e);
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void channelUnbound(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.channelUnbound(ctx, e);
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the child channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void childChannelClosed(ChannelHandlerContext ctx, ChildChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.childChannelClosed(ctx, e);
    }

    /**
     * Records the handler context and delegates to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param e the child channel state event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void childChannelOpen(ChannelHandlerContext ctx, ChildChannelStateEvent e) throws Exception {
        this.context = ctx;
        super.childChannelOpen(ctx, e);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
        e.getCause().printStackTrace();
        this.context = ctx;
        if (log.isDebugEnabled()) {
            log.debug("Channel is disconnected:" + e.getCause().getLocalizedMessage());
        }
    }

    /**
     * Records the handler context and delegates the event to the superclass
     * implementation.
     *
     * @param ctx the channel handler context
     * @param arg1 the upstream channel event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void handleUpstream(ChannelHandlerContext ctx, ChannelEvent arg1) throws Exception {
        this.context = ctx;
        super.handleUpstream(ctx, arg1);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        this.context = ctx;
        if (log.isDebugEnabled()) {
            log.debug("message received in the engine handler");
        }
        Object msg = e.getMessage();
        if (msg instanceof HttpRequest) {
            handleHttpRequest(ctx, (HttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    /**
     * private method that sends the handshake response for WebSocket connection
     *
     * @param ctx the channel context
     * @param req http request object
     * @param res http response object
     */
    private void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) {
        // Generate an error page if response status code is not OK (200).
        if (res.getStatus().getCode() != 200) {
            res.setContent(ChannelBuffers.copiedBuffer(res.getStatus().toString(), CharsetUtil.UTF_8));
            setContentLength(res, res.getContent().readableBytes());
        }

        // Send the response and close the connection if necessary.
        ChannelFuture f = ctx.getChannel().write(res);
        if (!isKeepAlive(req) || res.getStatus().getCode() != 200) {
            f.addListener(ChannelFutureListener.CLOSE);
        }
    }

    /**
     * Check if the request header has Keep-Alive
     *
     * @param req the http request object
     * @return {@code true} if keep-alive is set in the header {@code false}
     *         otherwise
     */
    private boolean isKeepAlive(HttpRequest req) {
        String keepAlive = req.getHeader(HttpHeaders.Values.KEEP_ALIVE);
        if (keepAlive != null && keepAlive.length() > 0) {
            return true;
        } else {
            // TODO: Keep-Alive value is like 'timeout=15, max=500'
            return false;
        }
    }

    /**
     * Set the content length in the response
     *
     * @param res the http response object
     * @param readableBytes the length of the bytes
     */
    private void setContentLength(HttpResponse res, int readableBytes) {
        res.setHeader(CONTENT_LENGTH, readableBytes);
    }

    /**
     * private method that handles the web socket frame data, this method is
     * used only after the WebSocket connection is established.
     *
     * @param ctx the channel handler context
     * @param msg the web socket frame data
     */
    private void handleWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame msg) throws WebSocketRuntimeException {
        String textData = "";
        if (msg.isBinary()) {
            // TODO: handle binary data
        } else if (msg.isText()) {
            textData = msg.getTextData();
        } else {
            throw new WebSocketRuntimeException("Frame Doesn't contain any type of data");
        }
        engine.processPacket(connector, new RawPacket(textData));
    }

    /**
     * Handles the initial HTTP request for handshaking if the http request
     * contains Upgrade header value as WebSocket then this method sends the
     * handshake response and also fires the events on client connector.
     *
     * @param ctx the channel handler context
     * @param req  the request message
     */
    private void handleHttpRequest(ChannelHandlerContext ctx, HttpRequest req) {
        // Allow only GET methods.
        if (req.getMethod() != HttpMethod.GET) {
            sendHttpResponse(ctx, req, new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN));
            return;
        }
        // Serve the WebSocket handshake request.
        if (HttpHeaders.Values.UPGRADE.equalsIgnoreCase(req.getHeader(HttpHeaders.Names.CONNECTION)) &&
                HttpHeaders.Values.WEBSOCKET.equalsIgnoreCase(req.getHeader(HttpHeaders.Names.UPGRADE))) {
            // Create the WebSocket handshake response.
            HttpResponse response = null;
            try {
                response = constructHandShakeResponse(req, ctx);
            } catch (NoSuchAlgorithmException e) {
                // better to close the channel
                log.debug("Channel is disconnected");
                ctx.getChannel().close();
            }

            // write the response
            ctx.getChannel().write(response);

            channels.add(ctx.getChannel());

            // since handshaking is done, replace the encoder/decoder with
            // web socket data frame encoder/decoder
            ChannelPipeline p = ctx.getChannel().getPipeline();
            p.remove("aggregator");
            EngineConfiguration config = engine.getConfiguration();
            if (config == null || config.getMaxFramesize() == 0) {
                p.replace("decoder", "jwsdecoder", new WebSocketFrameDecoder(JWebSocketCommonConstants.DEFAULT_MAX_FRAME_SIZE));
            } else {
                p.replace("decoder", "jwsdecoder", new WebSocketFrameDecoder(config.getMaxFramesize()));
            }
            p.replace("encoder", "jwsencoder", new WebSocketFrameEncoder());
           
            //if the WebSocket connection URI is wss then start SSL TLS handshaking
            if (req.getUri().startsWith("wss:")) {
                // Get the SslHandler in the current pipeline.
                final SslHandler sslHandler = ctx.getPipeline().get(SslHandler.class);
                // Get notified when SSL handshake is done.
                ChannelFuture handshakeFuture = sslHandler.handshake();
                handshakeFuture.addListener(new SecureWebSocketConnectionListener(sslHandler));
            }
            // initialize the connector
            connector = initializeConnector(ctx, req);

            return;
        }

        // Send an error page otherwise.
        sendHttpResponse(ctx, req, new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN));
    }

    /**
     * Constructs the <tt>HttpResponse</tt> object for the handshake response
     *
     * @param req the http request object
     * @param ctx the channel handler context
     * @return the http handshake response
     * @throws NoSuchAlgorithmException
     */
    private HttpResponse constructHandShakeResponse(HttpRequest req, ChannelHandlerContext ctx) throws NoSuchAlgorithmException {
        HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, new HttpResponseStatus(101, "Web Socket Protocol Handshake"));
        res.addHeader(HttpHeaders.Names.UPGRADE, HttpHeaders.Values.WEBSOCKET);
        res.addHeader(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.UPGRADE);

        // Fill in the headers and contents depending on handshake method.
        if (req.containsHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY1) && req.containsHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY2)) {
            // New handshake method with a challenge:
            res.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_ORIGIN, req.getHeader(HttpHeaders.Names.ORIGIN));
            res.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_LOCATION, getWebSocketLocation(req));
            String protocol = req.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL);
            if (protocol != null) {
                res.addHeader(HttpHeaders.Names.SEC_WEBSOCKET_PROTOCOL, protocol);
            }

            // Calculate the answer of the challenge.
            String key1 = req.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY1);
            String key2 = req.getHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY2);
            int a = (int) (Long.parseLong(key1.replaceAll("[^0-9]", "")) / key1.replaceAll("[^ ]", "").length());
            int b = (int) (Long.parseLong(key2.replaceAll("[^0-9]", "")) / key2.replaceAll("[^ ]", "").length());
            long c = req.getContent().readLong();
            ChannelBuffer input = ChannelBuffers.buffer(16);
            input.writeInt(a);
            input.writeInt(b);
            input.writeLong(c);
            ChannelBuffer output = ChannelBuffers.wrappedBuffer(MessageDigest.getInstance("MD5").digest(input.array()));
            res.setContent(output);
        } else {
            // Old handshake method with no challenge:
            res.addHeader(HttpHeaders.Names.WEBSOCKET_ORIGIN, req.getHeader(HttpHeaders.Names.ORIGIN));
            res.addHeader(HttpHeaders.Names.WEBSOCKET_LOCATION, getWebSocketLocation(req));
            String protocol = req.getHeader(HttpHeaders.Names.WEBSOCKET_PROTOCOL);
            if (protocol != null) {
                res.addHeader(HttpHeaders.Names.WEBSOCKET_PROTOCOL, protocol);
            }
        }
        return res;

    }

    /**
     * Initialize the {@code NettyConnector} after initial handshaking is
     * successfull.
     *
     * @param ctx the channel handler context
     * @param req the http request object
     */
    private WebSocketConnector initializeConnector(ChannelHandlerContext ctx, HttpRequest req) {

        RequestHeader header = getRequestHeader(req);
        int lSessionTimeout = header.getTimeout(JWebSocketCommonConstants.DEFAULT_TIMEOUT);
        if (lSessionTimeout > 0) {
            ctx.getChannel().getConfig().setConnectTimeoutMillis(lSessionTimeout);
        }
        // create connector
        WebSocketConnector theConnector = new NettyConnector(engine, this);
        theConnector.setHeader(header);

        engine.getConnectors().put(theConnector.getId(), theConnector);
        theConnector.startConnector();
        // allow descendant classes to handle connector started event
        engine.connectorStarted(theConnector);
        return theConnector;

    }

    /**
     * Construct the request header to save it in the connector
     *
     * @param req the http request header
     * @return the request header
     */
    private RequestHeader getRequestHeader(HttpRequest req) {
        RequestHeader header = new RequestHeader();
        FastMap<String, String> args = new FastMap<String, String>();
        String searchString = "";
        String path = req.getUri();

        // isolate search string
        int pos = path.indexOf(JWebSocketCommonConstants.PATHARG_SEPARATOR);
        if (pos >= 0) {
            searchString = path.substring(pos + 1);
            if (searchString.length() > 0) {
                String[] lArgs = searchString.split(JWebSocketCommonConstants.ARGARG_SEPARATOR);
                for (int i = 0; i < lArgs.length; i++) {
                    String[] lKeyValuePair = lArgs[i].split(JWebSocketCommonConstants.KEYVAL_SEPARATOR, 2);
                    if (lKeyValuePair.length == 2) {
                        args.put(lKeyValuePair[0], lKeyValuePair[1]);
                        if (log.isDebugEnabled()) {
                            log.debug("arg" + i + ": " + lKeyValuePair[0] + "=" + lKeyValuePair[1]);
                        }
                    }
                }
            }
        }
        // set default sub protocol if none passed
        if (args.get("prot") == null) {
            args.put("prot", JWebSocketCommonConstants.SUB_PROT_DEFAULT);
        }
        header.put(ARGS, args);
        header.put(ORIGIN, req.getHeader(HttpHeaders.Names.ORIGIN));
        header.put(LOCATION, getWebSocketLocation(req));
        header.put(PATH, req.getUri());

        header.put(SEARCH_STRING, searchString);
        header.put(HOST, req.getHeader(HttpHeaders.Names.HOST));
        return header;
    }

    /**
     * Delegates to the superclass implementation without any additional
     * handling.
     *
     * @param ctx the channel handler context
     * @param e the write completion event
     * @throws Exception if the superclass implementation fails
     */
    @Override
    public void writeComplete(ChannelHandlerContext ctx, WriteCompletionEvent e) throws Exception {
        super.writeComplete(ctx, e);
    }

    /**
     * Returns the web socket location URL
     *
     * @param req the http request object
     * @return the location url string
     */
    private String getWebSocketLocation(HttpRequest req) {
        //TODO: fix this URL for wss: (secure)
        String location = "ws://" + req.getHeader(HttpHeaders.Names.HOST) + req.getUri();
        return location;
    }

    /**
     * Returns the channel handler context that was recorded by the most
     * recent event callback of this handler.
     *
     * @return the channel context, or {@code null} if no event callback has
     *         recorded one yet
     */
    public ChannelHandlerContext getChannelHandlerContext() {
        return context;
    }

    /**
     * Listener class for SSL TLS handshake completion.
     */
    private static final class SecureWebSocketConnectionListener implements ChannelFutureListener {

        private final SslHandler sslHandler;

        SecureWebSocketConnectionListener(SslHandler sslHandler) {
            this.sslHandler = sslHandler;
        }

        public void operationComplete(ChannelFuture future) throws Exception {
            if (future.isSuccess()) {
                // that means SSL handshaking is done.
            } else {
                future.getChannel().close();
            }
        }
    }
}
// TOP
//
// Related Classes of org.jwebsocket.netty.engines.NettyEngineHandler$SecureWebSocketConnectionListener
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.