Package org.apache.tomcat.util.net

Source Code of org.apache.tomcat.util.net.SecureNio2Channel$ReadCompletionHandler

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.util.net;

import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.nio.channels.ReadPendingException;
import java.nio.channels.WritePendingException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;

import org.apache.tomcat.util.res.StringManager;

/**
* Implementation of a secure socket channel for NIO2.
*/
public class SecureNio2Channel extends Nio2Channel  {

    protected static final StringManager sm = StringManager.getManager("org.apache.tomcat.util.net.res");

    // Encrypted data read from the socket, consumed by sslEngine.unwrap()
    protected ByteBuffer netInBuffer;
    // Encrypted data produced by sslEngine.wrap(), waiting to be written to the socket
    protected ByteBuffer netOutBuffer;

    protected SSLEngine sslEngine;
    // Supplies socket properties and timeouts, and resumes or closes the
    // socket once non-blocking handshake I/O completes
    protected final Nio2Endpoint endpoint;

    // Becomes true once the handshake reached FINISHED with netOutBuffer fully sent;
    // cleared again by reset() and rehandshake()
    protected boolean handshakeComplete;
    protected HandshakeStatus handshakeStatus; //gets set by handshake

    // closing: close() has started; closed: the TLS close completed, or a close was forced
    protected boolean closed;
    protected boolean closing;
    // Set while a read/write is outstanding; used to reject overlapping operations
    protected volatile boolean readPending;
    protected volatile boolean writePending;

    // Completion handlers for non-blocking handshake I/O (created in the constructor)
    private CompletionHandler<Integer, SocketWrapper<Nio2Channel>> handshakeReadCompletionHandler;
    private CompletionHandler<Integer, SocketWrapper<Nio2Channel>> handshakeWriteCompletionHandler;

    public SecureNio2Channel(SSLEngine engine, ApplicationBufferHandler bufHandler,
            Nio2Endpoint endpoint0) {
        super(bufHandler);
        sslEngine = engine;
        endpoint = endpoint0;
        int netBufSize = sslEngine.getSession().getPacketBufferSize();
        if (endpoint.getSocketProperties().getDirectSslBuffer()) {
            netInBuffer = ByteBuffer.allocateDirect(netBufSize);
            netOutBuffer = ByteBuffer.allocateDirect(netBufSize);
        } else {
            netInBuffer = ByteBuffer.allocate(netBufSize);
            netOutBuffer = ByteBuffer.allocate(netBufSize);
        }
        handshakeReadCompletionHandler = new CompletionHandler<Integer, SocketWrapper<Nio2Channel>>() {
            @Override
            public void completed(Integer result, SocketWrapper<Nio2Channel> attachment) {
                if (result.intValue() < 0) {
                    failed(new EOFException(), attachment);
                    return;
                }
                endpoint.processSocket(attachment, SocketStatus.OPEN_READ, false);
            }
            @Override
            public void failed(Throwable exc, SocketWrapper<Nio2Channel> attachment) {
                endpoint.closeSocket(attachment, SocketStatus.ERROR);
            }
        };
        handshakeWriteCompletionHandler = new CompletionHandler<Integer, SocketWrapper<Nio2Channel>>() {
            @Override
            public void completed(Integer result, SocketWrapper<Nio2Channel> attachment) {
                if (result.intValue() < 0) {
                    failed(new EOFException(), attachment);
                    return;
                }
                endpoint.processSocket(attachment, SocketStatus.OPEN_WRITE, false);
            }
            @Override
            public void failed(Throwable exc, SocketWrapper<Nio2Channel> attachment) {
                endpoint.closeSocket(attachment, SocketStatus.ERROR);
            }
        };
    }

    public void setSSLEngine(SSLEngine engine) {
        this.sslEngine = engine;
    }

    @Override
    public void reset(AsynchronousSocketChannel channel, SocketWrapper<Nio2Channel> socket)
            throws IOException {
        super.reset(channel, socket);
        netOutBuffer.position(0);
        netOutBuffer.limit(0);
        netInBuffer.position(0);
        netInBuffer.limit(0);
        handshakeComplete = false;
        closed = false;
        closing = false;
        readPending = false;
        writePending = false;
        //initiate handshake
        sslEngine.beginHandshake();
        handshakeStatus = sslEngine.getHandshakeStatus();
    }

    @Override
    public int getBufferSize() {
        int size = super.getBufferSize();
        size += netInBuffer!=null?netInBuffer.capacity():0;
        size += netOutBuffer!=null?netOutBuffer.capacity():0;
        return size;
    }

    private class FutureFlush implements Future<Boolean> {
        private Future<Integer> integer;
        protected FutureFlush(Future<Integer> integer) {
            this.integer = integer;
        }
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return integer.cancel(mayInterruptIfRunning);
        }
        @Override
        public boolean isCancelled() {
            return integer.isCancelled();
        }
        @Override
        public boolean isDone() {
            return integer.isDone();
        }
        @Override
        public Boolean get() throws InterruptedException,
                ExecutionException {
            try {
                int result = integer.get().intValue();
                return Boolean.valueOf(result >= 0);
            } finally {
                writePending = false;
            }
        }
        @Override
        public Boolean get(long timeout, TimeUnit unit)
                throws InterruptedException, ExecutionException,
                TimeoutException {
            try {
                int result = integer.get(timeout, unit).intValue();
                return Boolean.valueOf(result >= 0);
            } finally {
                writePending = false;
            }
        }
    }

    /**
     * Flush the channel.
     *
     * @return <code>true</code> if the network buffer has been flushed out and
     *         is empty else <code>false</code> (as a future)
     */
    @Override
    public Future<Boolean> flush() {
        if (writePending) {
            throw new WritePendingException();
        } else {
            writePending = true;
        }
        return new FutureFlush(sc.write(netOutBuffer));
    }

    /**
     * Performs SSL handshake, non blocking, but performs NEED_TASK on the same thread.<br>
     * Hence, you should never call this method using your Acceptor thread, as you would slow down
     * your system significantly.<br>
     * The return for this operation is 0 if the handshake is complete and a positive value if it is not complete.
     * In the event of a positive value coming back, reregister the selection key for the return values interestOps.
     *
     * @return int - 0 if hand shake is complete, otherwise it returns a SelectionKey interestOps value
     * @throws IOException
     */
    @Override
    public int handshake() throws IOException {
        return handshakeInternal(true);
    }

    protected int handshakeInternal(boolean async) throws IOException {
        if (handshakeComplete)
            return 0; //we have done our initial handshake

        SSLEngineResult handshake = null;

        while (!handshakeComplete) {
            switch (handshakeStatus) {
                case NOT_HANDSHAKING: {
                    //should never happen
                    throw new IOException(sm.getString("channel.nio.ssl.notHandshaking"));
                }
                case FINISHED: {
                    //we are complete if we have delivered the last package
                    handshakeComplete = !netOutBuffer.hasRemaining();
                    //return 0 if we are complete, otherwise we still have data to write
                    if (handshakeComplete) {
                        return 0;
                    } else {
                        if (async) {
                            sc.write(netOutBuffer, socket, handshakeWriteCompletionHandler);
                        } else {
                            try {
                                sc.write(netOutBuffer).get(endpoint.getSoTimeout(), TimeUnit.MILLISECONDS);
                            } catch (InterruptedException | ExecutionException | TimeoutException e) {
                                throw new IOException(sm.getString("channel.nio.ssl.handhakeError"));
                            }
                        }
                        return 1;
                    }
                }
                case NEED_WRAP: {
                    //perform the wrap function
                    handshake = handshakeWrap();
                    if (handshake.getStatus() == Status.OK){
                        if (handshakeStatus == HandshakeStatus.NEED_TASK)
                            handshakeStatus = tasks();
                    } else {
                        //wrap should always work with our buffers
                        throw new IOException(sm.getString("channel.nio.ssl.unexpectedStatusDuringWrap", handshake.getStatus()));
                    }
                    if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || netOutBuffer.remaining() > 0) {
                        //should actually return OP_READ if we have NEED_UNWRAP
                        if (async) {
                            sc.write(netOutBuffer, socket, handshakeWriteCompletionHandler);
                        } else {
                            try {
                                sc.write(netOutBuffer).get(endpoint.getSoTimeout(), TimeUnit.MILLISECONDS);
                            } catch (InterruptedException | ExecutionException | TimeoutException e) {
                                throw new IOException(sm.getString("channel.nio.ssl.handhakeError"));
                            }
                        }
                        return 1;
                    }
                    //fall down to NEED_UNWRAP on the same call, will result in a
                    //BUFFER_UNDERFLOW if it needs data
                }
                //$FALL-THROUGH$
                case NEED_UNWRAP: {
                    //perform the unwrap function
                    handshake = handshakeUnwrap();
                    if (handshake.getStatus() == Status.OK) {
                        if (handshakeStatus == HandshakeStatus.NEED_TASK)
                            handshakeStatus = tasks();
                    } else if (handshake.getStatus() == Status.BUFFER_UNDERFLOW) {
                        //read more data, reregister for OP_READ
                        if (async) {
                            sc.read(netInBuffer, socket, handshakeReadCompletionHandler);
                        } else {
                            try {
                                sc.read(netInBuffer).get(endpoint.getSoTimeout(), TimeUnit.MILLISECONDS);
                            } catch (InterruptedException | ExecutionException | TimeoutException e) {
                                throw new IOException(sm.getString("channel.nio.ssl.handhakeError"));
                            }
                        }
                        return 1;
                    } else {
                        throw new IOException(sm.getString("channel.nio.ssl.unexpectedStatusDuringUnwrap", handshakeStatus));
                    }
                    break;
                }
                case NEED_TASK: {
                    handshakeStatus = tasks();
                    break;
                }
                default: throw new IllegalStateException(sm.getString("channel.nio.ssl.invalidStatus", handshakeStatus));
            }
        }
        //return 0 if we are complete, otherwise recurse to process the task
        return handshakeComplete ? 0 : handshakeInternal(async);
    }

    /**
     * Force a blocking handshake to take place for this key.
     * This requires that both network and application buffers have been emptied out prior to this call taking place, or a
     * IOException will be thrown.
     * @throws IOException - if an IO exception occurs or if application or network buffers contain data
     * @throws java.net.SocketTimeoutException - if a socket operation timed out
     */
    public void rehandshake() throws IOException {
        //validate the network buffers are empty
        if (netInBuffer.position() > 0 && netInBuffer.position() < netInBuffer.limit()) throw new IOException(sm.getString("channel.nio.ssl.netInputNotEmpty"));
        if (netOutBuffer.position() > 0 && netOutBuffer.position() < netOutBuffer.limit()) throw new IOException(sm.getString("channel.nio.ssl.netOutputNotEmpty"));
        ByteBuffer readBuffer = getBufHandler().getReadBuffer();
        ByteBuffer writeBuffer = getBufHandler().getWriteBuffer();
        if (readBuffer.position() > 0 && readBuffer.position() < readBuffer.limit()) throw new IOException(sm.getString("channel.nio.ssl.appInputNotEmpty"));
        if (writeBuffer.position() > 0 && writeBuffer.position() < writeBuffer.limit()) throw new IOException(sm.getString("channel.nio.ssl.appOutputNotEmpty"));

        netOutBuffer.position(0);
        netOutBuffer.limit(0);
        netInBuffer.position(0);
        netInBuffer.limit(0);
        readBuffer.clear();
        writeBuffer.clear();

        handshakeComplete = false;
        //initiate handshake
        sslEngine.beginHandshake();
        handshakeStatus = sslEngine.getHandshakeStatus();

        boolean handshaking = true;
        try {
            while (handshaking) {
                int hsStatus = handshakeInternal(false);
                switch (hsStatus) {
                    case -1 : throw new EOFException(sm.getString("channel.nio.ssl.eofDuringHandshake"));
                    case  0 : handshaking = false; break;
                    default : // Some blocking IO occurred, so iterate
                }
            }
        } catch (IOException x) {
            throw x;
        } catch (Exception cx) {
            IOException x = new IOException(cx);
            throw x;
        }
    }


    /**
     * Executes all the tasks needed on the same thread.
     * @return HandshakeStatus
     */
    protected SSLEngineResult.HandshakeStatus tasks() {
        Runnable r = null;
        while ( (r = sslEngine.getDelegatedTask()) != null) {
            r.run();
        }
        return sslEngine.getHandshakeStatus();
    }

    /**
     * Performs the WRAP function
     * @return SSLEngineResult
     * @throws IOException
     */
    protected SSLEngineResult handshakeWrap() throws IOException {
        //this should never be called with a network buffer that contains data
        //so we can clear it here.
        netOutBuffer.clear();
        //perform the wrap
        SSLEngineResult result = sslEngine.wrap(bufHandler.getWriteBuffer(), netOutBuffer);
        //prepare the results to be written
        netOutBuffer.flip();
        //set the status
        handshakeStatus = result.getHandshakeStatus();
        return result;
    }

    /**
     * Perform handshake unwrap
     * @return SSLEngineResult
     * @throws IOException
     */
    protected SSLEngineResult handshakeUnwrap() throws IOException {

        if (netInBuffer.position() == netInBuffer.limit()) {
            //clear the buffer if we have emptied it out on data
            netInBuffer.clear();
        }
        SSLEngineResult result;
        boolean cont = false;
        //loop while we can perform pure SSLEngine data
        do {
            //prepare the buffer with the incoming data
            netInBuffer.flip();
            //call unwrap
            result = sslEngine.unwrap(netInBuffer, bufHandler.getReadBuffer());
            //compact the buffer, this is an optional method, wonder what would happen if we didn't
            netInBuffer.compact();
            //read in the status
            handshakeStatus = result.getHandshakeStatus();
            if (result.getStatus() == SSLEngineResult.Status.OK &&
                 result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                //execute tasks if we need to
                handshakeStatus = tasks();
            }
            //perform another unwrap?
            cont = result.getStatus() == SSLEngineResult.Status.OK &&
                   handshakeStatus == HandshakeStatus.NEED_UNWRAP;
        } while (cont);
        return result;
    }

    /**
     * Sends a SSL close message, will not physically close the connection here.<br>
     * To close the connection, you could do something like
     * <pre><code>
     *   close();
     *   while (isOpen() &amp;&amp; !myTimeoutFunction()) Thread.sleep(25);
     *   if ( isOpen() ) close(true); //forces a close if you timed out
     * </code></pre>
     * @throws IOException if an I/O error occurs
     * @throws IOException if there is data on the outgoing network buffer and we are unable to flush it
     */
    @Override
    public void close() throws IOException {
        if (closing) return;
        closing = true;
        sslEngine.closeOutbound();

        try {
            if (!flush().get(endpoint.getSoTimeout(), TimeUnit.MILLISECONDS).booleanValue()) {
                throw new IOException(sm.getString("channel.nio.ssl.remainingDataDuringClose"));
            }
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw new IOException(sm.getString("channel.nio.ssl.remainingDataDuringClose"), e);
        } catch (WritePendingException e) {
            throw new IOException(sm.getString("channel.nio.ssl.pendingWriteDuringClose"), e);
        }
        //prep the buffer for the close message
        netOutBuffer.clear();
        //perform the close, since we called sslEngine.closeOutbound
        SSLEngineResult handshake = sslEngine.wrap(getEmptyBuf(), netOutBuffer);
        //we should be in a close state
        if (handshake.getStatus() != SSLEngineResult.Status.CLOSED) {
            throw new IOException(sm.getString("channel.nio.ssl.invalidCloseState"));
        }
        //prepare the buffer for writing
        netOutBuffer.flip();
        //if there is data to be written
        try {
            if (!flush().get(endpoint.getSoTimeout(), TimeUnit.MILLISECONDS).booleanValue()) {
                throw new IOException(sm.getString("channel.nio.ssl.remainingDataDuringClose"));
            }
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            throw new IOException(sm.getString("channel.nio.ssl.remainingDataDuringClose"), e);
        } catch (WritePendingException e) {
            throw new IOException(sm.getString("channel.nio.ssl.pendingWriteDuringClose"), e);
        }

        //is the channel closed?
        closed = (!netOutBuffer.hasRemaining() && (handshake.getHandshakeStatus() != HandshakeStatus.NEED_WRAP));
    }

    /**
     * Force a close, can throw an IOException
     * @param force boolean
     * @throws IOException
     */
    @Override
    public void close(boolean force) throws IOException {
        try {
            close();
        } finally {
            if ( force || closed ) {
                closed = true;
                sc.close();
            }
        }
    }

    private class FutureRead implements Future<Integer> {
        private ByteBuffer dst;
        public FutureRead(ByteBuffer dst) {
            this.dst = dst;
        }
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return false;
        }
        @Override
        public boolean isCancelled() {
            return false;
        }
        @Override
        public boolean isDone() {
            return true;
        }
        @Override
        public Integer get() throws InterruptedException, ExecutionException {
            try {
                return unwrap(netInBuffer.position());
            } finally {
                readPending = false;
            }
        }
        @Override
        public Integer get(long timeout, TimeUnit unit)
                throws InterruptedException, ExecutionException,
                TimeoutException {
            try {
                return unwrap(netInBuffer.position());
            } finally {
                readPending = false;
            }
        }
        protected Integer unwrap(int netread) throws ExecutionException {
            //are we in the middle of closing or closed?
            if (closing || closed)
                return Integer.valueOf(-1);
            //did we reach EOF? if so send EOF up one layer.
            if (netread == -1)
                return Integer.valueOf(-1);
            //the data read
            int read = 0;
            //the SSL engine result
            SSLEngineResult unwrap;
            do {
                //prepare the buffer
                netInBuffer.flip();
                //unwrap the data
                try {
                    unwrap = sslEngine.unwrap(netInBuffer, dst);
                } catch (SSLException e) {
                    throw new ExecutionException(e);
                }
                //compact the buffer
                netInBuffer.compact();
                if (unwrap.getStatus()==Status.OK || unwrap.getStatus()==Status.BUFFER_UNDERFLOW) {
                    //we did receive some data, add it to our total
                    read += unwrap.bytesProduced();
                    //perform any tasks if needed
                    if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK)
                        tasks();
                    //if we need more network data, then bail out for now.
                    if (unwrap.getStatus() == Status.BUFFER_UNDERFLOW)
                        break;
                } else if (unwrap.getStatus()==Status.BUFFER_OVERFLOW && read > 0) {
                    //buffer overflow can happen, if we have read data, then
                    //empty out the dst buffer before we do another read
                    break;
                } else {
                    //here we should trap BUFFER_OVERFLOW and call expand on the buffer
                    //for now, throw an exception, as we initialized the buffers
                    //in the constructor
                    throw new ExecutionException(new IOException(sm.getString("channel.nio.ssl.unwrapFail", unwrap.getStatus())));
                }
            } while ((netInBuffer.position() != 0)); //continue to unwrapping as long as the input buffer has stuff
            return Integer.valueOf(read);
        }
    }

    private class FutureNetRead extends FutureRead {
        private Future<Integer> integer;
        protected FutureNetRead(ByteBuffer dst) {
            super(dst);
            this.integer = sc.read(netInBuffer);
        }
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return integer.cancel(mayInterruptIfRunning);
        }
        @Override
        public boolean isCancelled() {
            return integer.isCancelled();
        }
        @Override
        public boolean isDone() {
            return integer.isDone();
        }
        @Override
        public Integer get() throws InterruptedException, ExecutionException {
            try {
                int netread = integer.get().intValue();
                return unwrap(netread);
            } finally {
                readPending = false;
            }
        }
        @Override
        public Integer get(long timeout, TimeUnit unit)
                throws InterruptedException, ExecutionException,
                TimeoutException {
            try {
                int netread = integer.get(timeout, unit).intValue();
                return unwrap(netread);
            } finally {
                readPending = false;
            }
        }
    }

    /**
     * Reads a sequence of bytes from this channel into the given buffer.
     *
     * @param dst The buffer into which bytes are to be transferred
     * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached end-of-stream
     * @throws IllegalStateException if the handshake was not completed
     */
    @Override
    public Future<Integer> read(ByteBuffer dst) {
        if (readPending) {
            throw new ReadPendingException();
        } else {
            readPending = true;
        }
        //did we finish our handshake?
        if (!handshakeComplete)
            throw new IllegalStateException(sm.getString("channel.nio.ssl.incompleteHandshake"));
        if (netInBuffer.position() > 0) {
            return new FutureRead(dst);
        } else {
            return new FutureNetRead(dst);
        }
    }

    private class FutureWrite implements Future<Integer> {
        private ByteBuffer src;
        private Future<Integer> integer = null;
        private int written = 0;
        private Throwable t = null;
        protected FutureWrite(ByteBuffer src) {
            //are we closing or closed?
            if (closing || closed) {
                t = new IOException(sm.getString("channel.nio.ssl.closing"));
                return;
            }
            this.src = src;
            wrap();
        }
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return integer.cancel(mayInterruptIfRunning);
        }
        @Override
        public boolean isCancelled() {
            return integer.isCancelled();
        }
        @Override
        public boolean isDone() {
            return integer.isDone();
        }
        @Override
        public Integer get() throws InterruptedException, ExecutionException {
            if (t != null) {
                writePending = false;
                throw new ExecutionException(t);
            }
            integer.get();
            if (written == 0) {
                wrap();
                return get();
            } else {
                writePending = false;
                return Integer.valueOf(written);
            }
        }
        @Override
        public Integer get(long timeout, TimeUnit unit)
                throws InterruptedException, ExecutionException,
                TimeoutException {
            if (t != null) {
                writePending = false;
                throw new ExecutionException(t);
            }
            integer.get(timeout, unit);
            if (written == 0) {
                wrap();
                return get(timeout, unit);
            } else {
                writePending = false;
                return Integer.valueOf(written);
            }
        }
        protected void wrap() {
            //The data buffer should be empty, we can reuse the entire buffer.
            netOutBuffer.clear();
            try {
                SSLEngineResult result = sslEngine.wrap(src, netOutBuffer);
                written = result.bytesConsumed();
                netOutBuffer.flip();
                if (result.getStatus() == Status.OK) {
                    if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK)
                        tasks();
                } else {
                    t = new IOException(sm.getString("channel.nio.ssl.wrapFail", result.getStatus()));
                }
                integer = sc.write(netOutBuffer);
            } catch (SSLException e) {
                t = e;
            }
        }
    }

    /**
     * Writes a sequence of bytes to this channel from the given buffer.
     *
     * @param src The buffer from which bytes are to be retrieved
     * @return The number of bytes written, possibly zero
     */
    @Override
    public Future<Integer> write(ByteBuffer src) {
        if (writePending) {
            throw new WritePendingException();
        } else {
            writePending = true;
        }
        return new FutureWrite(src);
    }

    private class ReadCompletionHandler<A> implements CompletionHandler<Integer, A> {
        protected ByteBuffer dst;
        protected CompletionHandler<Integer, ? super A> handler;
        protected ReadCompletionHandler(ByteBuffer dst, CompletionHandler<Integer, ? super A> handler) {
            this.dst = dst;
            this.handler = handler;
        }

        @Override
        public void completed(Integer nBytes, A attach) {
            if (nBytes.intValue() < 0) {
                failed(new EOFException(), attach);
            } else {
                try {
                    //the data read
                    int read = 0;
                    //the SSL engine result
                    SSLEngineResult unwrap;
                    do {
                        //prepare the buffer
                        netInBuffer.flip();
                        //unwrap the data
                        unwrap = sslEngine.unwrap(netInBuffer, dst);
                        //compact the buffer
                        netInBuffer.compact();
                        if (unwrap.getStatus() == Status.OK || unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
                            //we did receive some data, add it to our total
                            read += unwrap.bytesProduced();
                            //perform any tasks if needed
                            if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK)
                                tasks();
                            //if we need more network data, then bail out for now.
                            if (unwrap.getStatus() == Status.BUFFER_UNDERFLOW)
                                break;
                        } else if (unwrap.getStatus() == Status.BUFFER_OVERFLOW && read > 0) {
                            //buffer overflow can happen, if we have read data, then
                            //empty out the dst buffer before we do another read
                            break;
                        } else {
                            //here we should trap BUFFER_OVERFLOW and call expand on the buffer
                            //for now, throw an exception, as we initialized the buffers
                            //in the constructor
                            throw new IOException(sm.getString("channel.nio.ssl.unwrapFail", unwrap.getStatus()));
                        }
                    } while ((netInBuffer.position() != 0)); //continue to unwrapping as long as the input buffer has stuff
                    // If everything is OK, so complete
                    readPending = false;
                    handler.completed(Integer.valueOf(read), attach);
                } catch (Exception e) {
                    failed(e, attach);
                }
            }
        }
        @Override
        public void failed(Throwable exc, A attach) {
            readPending = false;
            handler.failed(exc, attach);
        }
    }

    @Override
    public <A> void read(final ByteBuffer dst,
            long timeout, TimeUnit unit, final A attachment,
            final CompletionHandler<Integer, ? super A> handler) {
        // Check state
        if (closing || closed) {
            handler.completed(Integer.valueOf(-1), attachment);
            return;
        }
        if (readPending) {
            throw new ReadPendingException();
        } else {
            readPending = true;
        }
        if (!handshakeComplete) {
            throw new IllegalStateException(sm.getString("channel.nio.ssl.incompleteHandshake"));
        }

        ReadCompletionHandler<A> readCompletionHandler = new ReadCompletionHandler<>(dst, handler);
        if (netInBuffer.position() > 0 ) {
            readCompletionHandler.completed(Integer.valueOf(netInBuffer.position()), attachment);
        } else {
            sc.read(netInBuffer, timeout, unit, attachment, readCompletionHandler);
        }
    }

    @Override
    public <A> void write(final ByteBuffer src, final long timeout, final TimeUnit unit,
            final A attachment, final CompletionHandler<Integer, ? super A> handler) {
        // Check state
        if (closing || closed) {
            handler.failed(new IOException(sm.getString("channel.nio.ssl.closing")), attachment);
            return;
        }
        if (writePending) {
            throw new WritePendingException();
        } else {
            writePending = true;
        }

        try {
            // Prepare the output buffer
            netOutBuffer.clear();
            // Wrap the source data into the internal buffer
            SSLEngineResult result = sslEngine.wrap(src, netOutBuffer);
            final int written = result.bytesConsumed();
            netOutBuffer.flip();
            if (result.getStatus() == Status.OK) {
                if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                    tasks();
                }
                // Write data to the channel
                sc.write(netOutBuffer, timeout, unit, attachment,
                        new CompletionHandler<Integer, A>() {
                    @Override
                    public void completed(Integer nBytes, A attach) {
                        if (nBytes.intValue() < 0) {
                            failed(new EOFException(), attach);
                        } else if (netOutBuffer.hasRemaining()) {
                            sc.write(netOutBuffer, timeout, unit, attachment, this);
                        } else if (written == 0) {
                            // Special case, start over to avoid code duplication
                            writePending = false;
                            write(src, timeout, unit, attachment, handler);
                        } else {
                            // Call the handler completed method with the
                            // consumed bytes number
                            writePending = false;
                            handler.completed(Integer.valueOf(written), attach);
                        }
                    }
                    @Override
                    public void failed(Throwable exc, A attach) {
                        writePending = false;
                        handler.failed(exc, attach);
                    }
                });
            } else {
                throw new IOException(sm.getString("channel.nio.ssl.wrapFail", result.getStatus()));
            }
        } catch (Exception e) {
            writePending = false;
            handler.failed(e, attachment);
        }
    }

    private class GatherState<A> {
        public ByteBuffer[] srcs;
        public int offset;
        public int length;
        public A attachment;
        public long timeout;
        public TimeUnit unit;
        public CompletionHandler<Long, ? super A> handler;
        protected GatherState(ByteBuffer[] srcs, int offset, int length,
                long timeout, TimeUnit unit, A attachment,
                CompletionHandler<Long, ? super A> handler) {
            this.srcs = srcs;
            this.offset = offset;
            this.length = length;
            this.timeout = timeout;
            this.unit = unit;
            this.attachment = attachment;
            this.handler = handler;
            this.pos = offset;
        }
        public long writeCount = 0;
        public int pos;
    }

    private class GatherCompletionHandler<A> implements CompletionHandler<Integer, GatherState<A>> {
        protected GatherState<A> state;
        protected GatherCompletionHandler(GatherState<A> state) {
            this.state = state;
        }
        @Override
        public void completed(Integer nBytes, GatherState<A> attachment) {
            if (nBytes.intValue() < 0) {
                failed(new EOFException(), attachment);
            } else {
                if (state.pos == state.offset + state.length) {
                    writePending = false;
                    state.handler.completed(Long.valueOf(state.writeCount), state.attachment);
                } else if (netOutBuffer.hasRemaining()) {
                    sc.write(netOutBuffer, state.timeout, state.unit, state, this);
                } else {
                    try {
                        // Prepare the output buffer
                        netOutBuffer.clear();
                        // Wrap the source data into the internal buffer
                        SSLEngineResult result = sslEngine.wrap(state.srcs[state.pos], netOutBuffer);
                        int written = result.bytesConsumed();
                        state.writeCount += written;
                        netOutBuffer.flip();
                        if (result.getStatus() == Status.OK) {
                            if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                                tasks();
                            }
                            if (!state.srcs[state.pos].hasRemaining()) {
                                state.pos++;
                            }
                            // Write data to the channel
                            sc.write(netOutBuffer, state.timeout, state.unit, state, this);
                        } else {
                            throw new IOException(sm.getString("channel.nio.ssl.wrapFail", result.getStatus()));
                        }
                    } catch (Exception e) {
                        failed(e, attachment);
                    }
                }
            }
        }
        @Override
        public void failed(Throwable exc, GatherState<A> attachment) {
            writePending = false;
            state.handler.failed(exc, state.attachment);
        }
    }

    @Override
    public <A> void write(ByteBuffer[] srcs, int offset, int length,
            long timeout, TimeUnit unit, A attachment,
            CompletionHandler<Long, ? super A> handler) {
        // Check state
        if ((offset < 0) || (length < 0) || (offset > srcs.length - length)) {
            throw new IndexOutOfBoundsException();
        }
        if (closing || closed) {
            handler.failed(new IOException(sm.getString("channel.nio.ssl.closing")), attachment);
            return;
        }
        if (writePending) {
            throw new WritePendingException();
        } else {
            writePending = true;
        }
        try {
            GatherState<A> state = new GatherState<>(srcs, offset, length,
                    timeout, unit, attachment, handler);
            // Prepare the output buffer
            netOutBuffer.clear();
            // Wrap the source data into the internal buffer
            SSLEngineResult result = sslEngine.wrap(srcs[offset], netOutBuffer);
            state.writeCount += result.bytesConsumed();
            netOutBuffer.flip();
            if (result.getStatus() == Status.OK) {
                if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                    tasks();
                }
                if (!srcs[offset].hasRemaining()) {
                    state.pos++;
                }
                // Write data to the channel
                sc.write(netOutBuffer, timeout, unit, state, new GatherCompletionHandler<>(state));
            } else {
                throw new IOException(sm.getString("channel.nio.ssl.wrapFail", result.getStatus()));
            }
        } catch (Exception e) {
            writePending = false;
            handler.failed(e, attachment);
        }
   }

    /**
     * Callback interface to be able to expand buffers
     * when buffer overflow exceptions happen
     */
    public static interface ApplicationBufferHandler {
        public ByteBuffer getReadBuffer();
        public ByteBuffer getWriteBuffer();
    }

    @Override
    public ApplicationBufferHandler getBufHandler() {
        return bufHandler;
    }

    @Override
    public boolean isHandshakeComplete() {
        return handshakeComplete;
    }

    @Override
    public boolean isClosing() {
        return closing;
    }

    public SSLEngine getSslEngine() {
        return sslEngine;
    }

    public ByteBuffer getEmptyBuf() {
        return emptyBuf;
    }

    public void setBufHandler(ApplicationBufferHandler bufHandler) {
        this.bufHandler = bufHandler;
    }

    @Override
    public AsynchronousSocketChannel getIOChannel() {
        return sc;
    }

}
TOP

Related Classes of org.apache.tomcat.util.net.SecureNio2Channel$ReadCompletionHandler

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.