Package io.undertow.client.ajp

Source Code of io.undertow.client.ajp.AjpClientRequestConduit

/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
*  Unless required by applicable law or agreed to in writing, software
*  distributed under the License is distributed on an "AS IS" BASIS,
*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*  See the License for the specific language governing permissions and
*  limitations under the License.
*/

package io.undertow.client.ajp;

import io.undertow.client.ClientRequest;
import io.undertow.client.ProxiedRequestAttachments;
import io.undertow.client.UndertowClientMessages;
import io.undertow.conduits.ConduitListener;
import io.undertow.util.FlexBase64;
import io.undertow.util.HeaderMap;
import io.undertow.util.HttpString;
import org.xnio.IoUtils;
import org.xnio.Pool;
import org.xnio.Pooled;
import org.xnio.channels.FixedLengthUnderflowException;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.AbstractStreamSinkConduit;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;

import static io.undertow.client.ajp.AjpConstants.ATTR_AUTH_TYPE;
import static io.undertow.client.ajp.AjpConstants.ATTR_QUERY_STRING;
import static io.undertow.client.ajp.AjpConstants.ATTR_REMOTE_USER;
import static io.undertow.client.ajp.AjpConstants.ATTR_ROUTE;
import static io.undertow.client.ajp.AjpConstants.ATTR_SECRET;
import static io.undertow.client.ajp.AjpConstants.ATTR_SSL_CERT;
import static io.undertow.client.ajp.AjpConstants.ATTR_SSL_CIPHER;
import static io.undertow.client.ajp.AjpConstants.ATTR_SSL_KEY_SIZE;
import static io.undertow.client.ajp.AjpConstants.ATTR_SSL_SESSION;
import static io.undertow.client.ajp.AjpConstants.ATTR_STORED_METHOD;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.Bits.anyAreSet;
import static org.xnio.Bits.longBitMask;

/**
* AJP client request channel. For now we are going to assume that the buffers are sized to
* fit complete packets. As AJP packets are limited to 8k this is a reasonable assumption.
*
* @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
* @author Stuart Douglas
*/
final class AjpClientRequestConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {

    private static final int MAX_DATA_SIZE = 8186;

    private final Pool<ByteBuffer> pool;

    /**
     * The current data buffer. This will be released once it has been written out.
     */
    private Pooled<ByteBuffer> currentDataBuffer;

    /**
     * header buffer for the current chunk, if it was not written out
     */
    private ByteBuffer headerDataBuffer;

    private final AjpClientExchange exchange;

    private final ConduitListener<? super AjpClientRequestConduit> finishListener;

    private final boolean hasContent;
    /**
     * State flags, with the chunk remaining stored in the low bytes
     */
    private long state;

    private long totalRemaining;

    private int requestedChunkSize = -1;

    /**
     * The remaining bits are used to store the remaining chunk size.
     */
    private static final long STATE_MASK = longBitMask(0, 57);

    private static final long FLAG_START = 1L << 63L; //indicates that the header has not been generated yet.
    private static final long FLAG_SHUTDOWN = 1L << 62L;
    private static final long FLAG_DELEGATE_SHUTDOWN = 1L << 61L;
    private static final long FLAG_WRITES_RESUMED = 1L << 60L;
    private static final long FLAG_FINAL_CHUNK_GENERATED = 1L << 59L;
    private static final long FLAG_DISCARD = 1L << 58L;


    AjpClientRequestConduit(final StreamSinkConduit next, final Pool<ByteBuffer> pool, final AjpClientExchange exchange, ConduitListener<? super AjpClientRequestConduit> finishListener, long size) {
        super(next);
        this.pool = pool;
        this.exchange = exchange;
        this.finishListener = finishListener;
        this.hasContent = size != 0;
        this.totalRemaining = size;
        state = FLAG_START;

        if (hasContent) {
            if (size > 0) {
                //fixed length
                requestedChunkSize = MAX_DATA_SIZE;
            } else {
                requestedChunkSize = 0;
            }
        }

    }

    private static void putInt(final ByteBuffer buf, int value) {
        buf.put((byte) ((value >> 8) & 0xFF));
        buf.put((byte) (value & 0xFF));
    }

    private static void putString(final ByteBuffer buf, String value) {
        final int length = value.length();
        putInt(buf, length);
        for (int i = 0; i < length; ++i) {
            buf.put((byte) value.charAt(i));
        }
        buf.put((byte) 0);
    }

    private static void putHttpString(final ByteBuffer buf, HttpString value) {
        final int length = value.length();
        putInt(buf, length);
        value.appendTo(buf);
        buf.put((byte) 0);
    }

    /**
     * Called when the target requests a body chunk
     * @param requestedSize The size of the requested chunk
     */
    void setBodyChunkRequested(int requestedSize) {
        this.requestedChunkSize = requestedSize;
        if (anyAreSet(state, FLAG_WRITES_RESUMED)) {
            next.resumeWrites();
        }
    }

    /**
     * Called then the request is done. This means no more chunks will be forthcoming,
     * and if the request has not been full written then the channel is closed.
     */
    void setRequestDone() {
        state |= FLAG_DISCARD;
        if (anyAreSet(state, FLAG_WRITES_RESUMED)) {
            next.wakeupWrites();
        }
    }

    /**
     * Handles writing out the header data, plus any current buffers. Returns true if the write can proceed,
     * false if there are still cached buffers
     *
     * @throws java.io.IOException
     */
    private boolean processWrite() throws IOException {
        if (anyAreSet(state, FLAG_DELEGATE_SHUTDOWN)) {
            return true;
        }

        //if currentDataBuffer is set then we just
        if (anyAreSet(state, FLAG_START)) {
            this.state &= ~FLAG_START;

            final ClientRequest request = exchange.getRequest();
            final String path;
            final String queryString;
            int qsIndex = exchange.getRequest().getPath().indexOf('?');
            if (qsIndex == -1) {
                path = exchange.getRequest().getPath();
                queryString = null;
            } else {
                path = exchange.getRequest().getPath().substring(0, qsIndex);
                queryString = exchange.getRequest().getPath().substring(qsIndex + 1);
            }

            currentDataBuffer = pool.allocate();
            final ByteBuffer buffer = currentDataBuffer.getResource();
            buffer.put((byte) 0x12);
            buffer.put((byte) 0x34);
            buffer.put((byte) 0); //we fill the size in later
            buffer.put((byte) 0);
            buffer.put((byte) 2);
            boolean storeMethod = false;
            Integer methodNp = AjpConstants.HTTP_METHODS_MAP.get(request.getMethod());
            if (methodNp == null) {
                methodNp = 0xFF;
                storeMethod = true;
            }
            buffer.put((byte) (int) methodNp);
            putHttpString(buffer, exchange.getRequest().getProtocol());
            putString(buffer, path);
            putString(buffer, notNull(request.getAttachment(ProxiedRequestAttachments.REMOTE_ADDRESS)));
            putString(buffer, notNull(request.getAttachment(ProxiedRequestAttachments.REMOTE_HOST)));
            putString(buffer, notNull(request.getAttachment(ProxiedRequestAttachments.SERVER_NAME)));
            putInt(buffer, notNull(request.getAttachment(ProxiedRequestAttachments.SERVER_PORT)));
            buffer.put((byte) (notNull(request.getAttachment(ProxiedRequestAttachments.IS_SSL)) ? 1 : 0));

            int headers = 0;
            //we need to count the headers
            final HeaderMap responseHeaders = request.getRequestHeaders();
            for (HttpString name : responseHeaders.getHeaderNames()) {
                headers += responseHeaders.get(name).size();
            }

            putInt(buffer, headers);


            for (final HttpString header : responseHeaders.getHeaderNames()) {
                for (String headerValue : responseHeaders.get(header)) {
                    Integer headerCode = AjpConstants.HEADER_MAP.get(header);
                    if (headerCode != null) {
                        putInt(buffer, headerCode);
                    } else {
                        putHttpString(buffer, header);
                    }
                    putString(buffer, headerValue);
                }
            }

            if (queryString != null) {
                buffer.put((byte) ATTR_QUERY_STRING); //query_string
                putString(buffer, queryString);
            }
            String remoteUser = request.getAttachment(ProxiedRequestAttachments.REMOTE_USER);
            if(remoteUser != null) {
                buffer.put((byte) ATTR_REMOTE_USER);
                putString(buffer, remoteUser);
            }
            String authType = request.getAttachment(ProxiedRequestAttachments.AUTH_TYPE);
            if(authType != null) {
                buffer.put((byte) ATTR_AUTH_TYPE);
                putString(buffer, authType);
            }
            String route = request.getAttachment(ProxiedRequestAttachments.ROUTE);
            if(route != null) {
                buffer.put((byte) ATTR_ROUTE);
                putString(buffer, route);
            }
            String sslCert = request.getAttachment(ProxiedRequestAttachments.SSL_CERT);
            if(sslCert != null) {
                buffer.put((byte) ATTR_SSL_CERT);
                putString(buffer, sslCert);
            }
            String sslCypher = request.getAttachment(ProxiedRequestAttachments.SSL_CYPHER);
            if(sslCypher != null) {
                buffer.put((byte) ATTR_SSL_CIPHER);
                putString(buffer, sslCypher);
            }
            byte[] sslSession = request.getAttachment(ProxiedRequestAttachments.SSL_SESSION_ID);
            if(sslSession != null) {
                buffer.put((byte) ATTR_SSL_SESSION);
                putString(buffer, FlexBase64.encodeString(sslSession, false));
            }
            Integer sslKeySize = request.getAttachment(ProxiedRequestAttachments.SSL_KEY_SIZE);
            if(sslKeySize != null) {
                buffer.put((byte) ATTR_SSL_KEY_SIZE);
                putString(buffer, sslKeySize.toString());
            }
            String secret = request.getAttachment(ProxiedRequestAttachments.SECRET);
            if(secret != null) {
                buffer.put((byte) ATTR_SECRET);
                putString(buffer, secret);
            }

            if(storeMethod) {
                buffer.put((byte) ATTR_STORED_METHOD);
                putString(buffer, request.getMethod().toString());
            }
            buffer.put((byte) 0xFF);

            int dataLength = buffer.position() - 4;
            buffer.put(2, (byte) ((dataLength >> 8) & 0xFF));
            buffer.put(3, (byte) (dataLength & 0xFF));
            buffer.flip();
            if (!hasContent) {
                this.state |= FLAG_SHUTDOWN;
            }
        }

        if (currentDataBuffer != null) {
            if (!writeCurrentBuffer()) {
                return false;
            }
        }
        return true;
    }

    /**
     * generates a final chunk for non fixed length requests
     *
     * @return
     * @throws IOException
     */
    private boolean handleFinalChunk() throws IOException {
        if (!hasContent) {
            return true;
        }
        if (anyAreSet(state, FLAG_SHUTDOWN) && !anyAreSet(state, FLAG_FINAL_CHUNK_GENERATED)) {
            state |= FLAG_FINAL_CHUNK_GENERATED;

            if (totalRemaining < 0) {
                byte[] header = new byte[6];
                header[0] = (byte) 0x12;
                header[1] = (byte) 0x34;
                header[2] = (byte) (0 & 0xFF);
                header[3] = (byte) (2 & 0xFF);
                header[4] = (byte) (0 & 0xFF);
                header[5] = (byte) (0 & 0xFF);
                ByteBuffer buffer = ByteBuffer.wrap(header);
                headerDataBuffer = buffer;
            }
        }
        if (headerDataBuffer != null) {

            ByteBuffer buffer = headerDataBuffer;
            int r;
            do {
                r = next.write(buffer);
                if (r == 0) {
                    return false;
                }
            } while (buffer.hasRemaining());
            headerDataBuffer = null;
            return true;
        }
        return true;
    }

    private boolean notNull(Boolean attachment) {
        return attachment == null ? false : attachment;
    }

    private int notNull(Integer attachment) {
        return attachment == null ? 0 : attachment;
    }

    private String notNull(String attachment) {
        return attachment == null ? "" : attachment;
    }

    private boolean writeCurrentBuffer() throws IOException {
        ByteBuffer buffer = currentDataBuffer.getResource();
        int r;
        do {
            r = next.write(buffer);
            if (r == 0) {
                return false;
            }
        } while (buffer.hasRemaining());
        currentDataBuffer.free();
        currentDataBuffer = null;
        return true;
    }


    public int write(final ByteBuffer src) throws IOException {
        if(anyAreSet(state, FLAG_DISCARD)) {
            int ret = src.remaining();
            src.position(src.limit());
            totalRemaining-=ret;
            return ret;
        }
        if(anyAreSet(state, FLAG_SHUTDOWN)) {
            throw new ClosedChannelException();
        }
        if (!processWrite()) {
            return 0;
        }
        if (src.remaining() == 0) {
            return 0;
        }
        long remaining = state & STATE_MASK;

        if (remaining == 0 && requestedChunkSize <= 0) {
            next.suspendWrites();
            return 0;
        }
        if (remaining == 0) {
            headerDataBuffer = createHeader(src);
            requestedChunkSize = 0;
            remaining = state & STATE_MASK; //this is a bit yuck
        }
        int limit = src.limit();
        if (src.remaining() > remaining) {
            src.limit((int) (src.position() + remaining));
        }
        try {
            ByteBuffer[] bufs;
            int headerLength = 0;
            if (src.remaining() == remaining) {
                if (headerDataBuffer == null) {
                    bufs = new ByteBuffer[]{src};
                } else {
                    bufs = new ByteBuffer[]{headerDataBuffer, src};
                    headerLength = headerDataBuffer.remaining();
                }
            } else {
                if (headerDataBuffer == null) {
                    bufs = new ByteBuffer[]{src};
                } else {
                    bufs = new ByteBuffer[]{headerDataBuffer, src};
                    headerLength = headerDataBuffer.remaining();
                }
            }
            int r = (int) next.write(bufs, 0, bufs.length);
            r -= headerLength;
            if(!headerDataBuffer.hasRemaining()) {
                headerDataBuffer = null;
            }
            if (r > 0) {
                remaining -= r;
                if (remaining < 0) {
                    remaining = 0;
                    r -= 1;
                }
                if(totalRemaining > 0) {
                    totalRemaining -= r;
                }
                return r;
            } else {
                return 0;
            }
        } finally {
            src.limit(limit);
            this.state = (state & ~STATE_MASK) | remaining;
        }
    }

    private ByteBuffer createHeader(final ByteBuffer src) {
        int remaining = src.remaining();
        remaining = Math.min(remaining, MAX_DATA_SIZE);
        remaining = Math.min(remaining, requestedChunkSize);
        int bodySize = remaining + 3;
        byte[] header = new byte[6];
        header[0] = (byte) 0x12;
        header[1] = (byte) 0x34;
        header[2] = (byte) ((bodySize >> 8) & 0xFF);
        header[3] = (byte) (bodySize & 0xFF);
        header[4] = (byte) ((remaining >> 8) & 0xFF);
        header[5] = (byte) (remaining & 0xFF);
        this.state = (state & ~STATE_MASK) | remaining;
        return ByteBuffer.wrap(header);
    }

    public long write(final ByteBuffer[] srcs) throws IOException {
        return write(srcs, 0, srcs.length);
    }

    public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
        long total = 0;
        for (int i = offset; i < offset + length; ++i) {
            while (srcs[i].hasRemaining()) {
                int written = write(srcs[i]);
                if (written <= 0 && total == 0) {
                    return written;
                } else if (written <= 0) {
                    return total;
                }
                total += written;
            }
        }
        return total;
    }

    @Override
    public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
        return Conduits.writeFinalBasic(this, srcs, offset, length);
    }

    @Override
    public int writeFinal(ByteBuffer src) throws IOException {
        return Conduits.writeFinalBasic(this, src);
    }

    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }

    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    public boolean flush() throws IOException {
        long state = this.state;
        boolean discard = anyAreSet(state, FLAG_DISCARD);
        if(!discard) {
            if (!processWrite()) {
                return false;
            }
        }
        if (allAreClear(state, FLAG_SHUTDOWN)) {
            return next.flush();
        }
        if(!discard) {
            if (!handleFinalChunk()) {
                return false;
            }
        }
        if (allAreSet(state, FLAG_SHUTDOWN) && allAreClear(state, FLAG_DELEGATE_SHUTDOWN)) {
            if (finishListener != null) {
                finishListener.handleEvent(this);
            }
            this.state |= FLAG_DELEGATE_SHUTDOWN;
        }
        return next.flush();
    }

    public void suspendWrites() {
        state &= ~FLAG_WRITES_RESUMED;
        next.suspendWrites();
    }

    public void resumeWrites() {
        state |= FLAG_WRITES_RESUMED;
        long remaining = state & STATE_MASK;
        if (remaining != 0 || requestedChunkSize != 0) {
            next.resumeWrites();
        }
    }

    public boolean isWriteResumed() {
        return anyAreSet(state, FLAG_WRITES_RESUMED);
    }

    public void wakeupWrites() {
        state |= FLAG_WRITES_RESUMED;
        next.wakeupWrites();
    }

    public void terminateWrites() throws IOException {
        long remaining = state & STATE_MASK;
        if (remaining != 0) {
            try {
                throw UndertowClientMessages.MESSAGES.dataStillRemainingInChunk(remaining);
            } finally {
                next.truncateWrites();
            }
        }
        if (totalRemaining > 0) {
            try {
                throw new FixedLengthUnderflowException(totalRemaining + " bytes remaining");
            } finally {
                next.truncateWrites();
            }
        }
        long state = this.state;
        if (anyAreSet(state, FLAG_SHUTDOWN)) {
            return;
        }
        this.state |= FLAG_SHUTDOWN;
    }

    public void awaitWritable() throws IOException {
        throw new IllegalStateException();
    }

    public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
        throw new IllegalStateException();
    }
}
TOP

Related Classes of io.undertow.client.ajp.AjpClientRequestConduit

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.