Package com.facebook.nifty.core

Source Code of com.facebook.nifty.core.NiftyDispatcher$DispatcherContext

/*
* Copyright (C) 2012-2013 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.nifty.core;

import com.facebook.nifty.duplex.TDuplexProtocolFactory;
import com.facebook.nifty.duplex.TProtocolPair;
import com.facebook.nifty.duplex.TTransportPair;
import com.facebook.nifty.processor.NiftyProcessorFactory;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.util.Timeout;
import org.jboss.netty.util.Timer;
import org.jboss.netty.util.TimerTask;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static com.google.common.base.Preconditions.checkState;

/**
* Dispatch TNiftyTransport to the TProcessor and write output back.
*
* Note that all current async thrift clients are capable of sending multiple requests at once
* but not capable of handling out-of-order responses to those requests, so this dispatcher
* sends the requests in order. (Eventually this will be conditional on a flag in the thrift
* message header for future async clients that can handle out-of-order responses).
*/
public class NiftyDispatcher extends SimpleChannelUpstreamHandler
{
    private final NiftyProcessorFactory processorFactory;
    private final Executor exe;
    private final long taskTimeoutMillis;
    private final Timer taskTimeoutTimer;
    private final int queuedResponseLimit;
    private final Map<Integer, ThriftMessage> responseMap = new HashMap<>();
    private final AtomicInteger dispatcherSequenceId = new AtomicInteger(0);
    private final AtomicInteger lastResponseWrittenId = new AtomicInteger(0);
    private final TDuplexProtocolFactory duplexProtocolFactory;

    public NiftyDispatcher(ThriftServerDef def, Timer timer)
    {
        this.processorFactory = def.getProcessorFactory();
        this.duplexProtocolFactory = def.getDuplexProtocolFactory();
        this.queuedResponseLimit = def.getQueuedResponseLimit();
        this.exe = def.getExecutor();
        this.taskTimeoutMillis = (def.getTaskTimeout() == null ? 0 : def.getTaskTimeout().toMillis());
        this.taskTimeoutTimer = (def.getTaskTimeout() == null ? null : timer);
    }

    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
            throws Exception
    {
        if (e.getMessage() instanceof ThriftMessage) {
            ThriftMessage message = (ThriftMessage) e.getMessage();
            if (taskTimeoutMillis > 0) {
                message.setProcessStartTimeMillis(System.currentTimeMillis());
            }
            checkResponseOrderingRequirements(ctx, message);

            TNiftyTransport messageTransport = new TNiftyTransport(ctx.getChannel(), message);
            TTransportPair transportPair = TTransportPair.fromSingleTransport(messageTransport);
            TProtocolPair protocolPair = duplexProtocolFactory.getProtocolPair(transportPair);
            TProtocol inProtocol = protocolPair.getInputProtocol();
            TProtocol outProtocol = protocolPair.getOutputProtocol();

            processRequest(ctx, message, messageTransport, inProtocol, outProtocol);
        }
        else {
            ctx.sendUpstream(e);
        }
    }

    private void checkResponseOrderingRequirements(ChannelHandlerContext ctx, ThriftMessage message)
    {
        boolean messageRequiresOrderedResponses = message.isOrderedResponsesRequired();

        if (!DispatcherContext.isResponseOrderingRequirementInitialized(ctx)) {
            // This is the first request. This message will decide whether all responses on the
            // channel must be strictly ordered, or whether out-of-order is allowed.
            DispatcherContext.setResponseOrderingRequired(ctx, messageRequiresOrderedResponses);
        }
        else {
            // This is not the first request. Verify that the ordering requirement on this message
            // is consistent with the requirement on the channel itself.
            checkState(
                    messageRequiresOrderedResponses == DispatcherContext.isResponseOrderingRequired(ctx),
                    "Every message on a single channel must specify the same requirement for response ordering");
        }
    }

    private void processRequest(
            final ChannelHandlerContext ctx,
            final ThriftMessage message,
            final TNiftyTransport messageTransport,
            final TProtocol inProtocol,
            final TProtocol outProtocol) {
        // Remember the ordering of requests as they arrive, used to enforce an order on the
        // responses.
        final int requestSequenceId = dispatcherSequenceId.incrementAndGet();

        synchronized (responseMap)
        {
            // Limit the number of pending responses (responses which finished out of order, and are
            // waiting for previous requests to be finished so they can be written in order), by
            // blocking further channel reads. Due to the way Netty frame decoders work, this is more
            // of an estimate than a hard limit. Netty may continue to decode and process several
            // more requests that were in the latest read, even while further reads on the channel
            // have been blocked.
            if (requestSequenceId > lastResponseWrittenId.get() + queuedResponseLimit &&
                !DispatcherContext.isChannelReadBlocked(ctx))
            {
                DispatcherContext.blockChannelReads(ctx);
            }
        }

        try {
            exe.execute(new Runnable() {
                @Override
                public void run() {
                    ListenableFuture<Boolean> processFuture;
                    final AtomicBoolean responseSent = new AtomicBoolean(false);
                    // Use AtomicReference as a generic holder class to be able to mark it final
                    // and pass into inner classes. Since we only use .get() and .set(), we don't
                    // actually do any atomic operations.
                    final AtomicReference<Timeout> expireTimeout = new AtomicReference<>(null);

                    try {
                        try {
                            long timeRemaining = 0;
                            if (taskTimeoutMillis > 0) {
                                long timeElapsed = System.currentTimeMillis() - message.getProcessStartTimeMillis();
                                if (timeElapsed >= taskTimeoutMillis) {
                                    TApplicationException taskTimeoutException = new TApplicationException(
                                            TApplicationException.INTERNAL_ERROR,
                                            "Task stayed on the queue for " + timeElapsed +
                                                    " milliseconds, exceeding configured task timeout of " + taskTimeoutMillis +
                                                    " milliseconds."
                                    );
                                    sendTApplicationException(taskTimeoutException, ctx, message, requestSequenceId, messageTransport,
                                            inProtocol, outProtocol);
                                    return;
                                } else {
                                    timeRemaining = taskTimeoutMillis - timeElapsed;
                                }
                            }

                            if (timeRemaining > 0) {
                                expireTimeout.set(taskTimeoutTimer.newTimeout(new TimerTask() {
                                    @Override
                                    public void run(Timeout timeout) throws Exception {
                                        // The immediateFuture returned by processors isn't cancellable, cancel() and
                                        // isCanceled() always return false. Use a flag to detect task expiration.
                                        if (responseSent.compareAndSet(false, true)) {
                                            TApplicationException ex = new TApplicationException(
                                                    TApplicationException.INTERNAL_ERROR,
                                                    "Task timed out while executing."
                                            );
                                            // Create a temporary transport to send the exception
                                            ChannelBuffer duplicateBuffer = message.getBuffer().duplicate();
                                            duplicateBuffer.resetReaderIndex();
                                            TNiftyTransport temporaryTransport = new TNiftyTransport(
                                                    ctx.getChannel(),
                                                    duplicateBuffer,
                                                    message.getTransportType());
                                            TProtocolPair protocolPair = duplexProtocolFactory.getProtocolPair(
                                                    TTransportPair.fromSingleTransport(temporaryTransport));
                                            sendTApplicationException(ex, ctx, message,
                                                    requestSequenceId,
                                                    temporaryTransport,
                                                    protocolPair.getInputProtocol(),
                                                    protocolPair.getOutputProtocol());
                                        }
                                    }
                                }, timeRemaining, TimeUnit.MILLISECONDS));
                            }

                            ConnectionContext connectionContext = ConnectionContexts.getContext(ctx.getChannel());
                            RequestContext requestContext = new NiftyRequestContext(connectionContext, inProtocol, outProtocol, messageTransport);
                            RequestContexts.setCurrentContext(requestContext);
                            processFuture = processorFactory.getProcessor(messageTransport).process(inProtocol, outProtocol, requestContext);
                        } finally {
                            // RequestContext does NOT stay set while we are waiting for the process
                            // future to complete. This is by design because we'll might move on to the
                            // next request using this thread before this one is completed. If you need
                            // the context throughout an asynchronous handler, you need to read and store
                            // it before returning a future.
                            RequestContexts.clearCurrentContext();
                        }

                        Futures.addCallback(
                                processFuture,
                                new FutureCallback<Boolean>() {
                                    @Override
                                    public void onSuccess(Boolean result) {
                                        deleteExpirationTimer(expireTimeout.get());
                                        try {
                                            // Only write response if the client is still there and the task timeout
                                            // hasn't expired.
                                            if (ctx.getChannel().isConnected() && responseSent.compareAndSet(false, true)) {
                                                ThriftMessage response = message.getMessageFactory().create(
                                                        messageTransport.getOutputBuffer());
                                                writeResponse(ctx, response, requestSequenceId,
                                                        DispatcherContext.isResponseOrderingRequired(ctx));
                                            }
                                        } catch (Throwable t) {
                                            onDispatchException(ctx, t);
                                        }
                                    }

                                    @Override
                                    public void onFailure(Throwable t) {
                                        deleteExpirationTimer(expireTimeout.get());
                                        onDispatchException(ctx, t);
                                    }
                                }
                        );
                    } catch (TException e) {
                        onDispatchException(ctx, e);
                    }
                }
            });
        }
        catch (RejectedExecutionException ex) {
            TApplicationException x = new TApplicationException(TApplicationException.INTERNAL_ERROR,
                    "Server overloaded");
            sendTApplicationException(x, ctx, message, requestSequenceId, messageTransport, inProtocol, outProtocol);
        }
    }

    private void deleteExpirationTimer(Timeout timeout)
    {
        if (timeout == null) {
            return;
        }
        timeout.cancel();
    }

    private void sendTApplicationException(
            TApplicationException x,
            ChannelHandlerContext ctx,
            ThriftMessage request,
            int responseSequenceId,
            TNiftyTransport requestTransport,
            TProtocol inProtocol,
            TProtocol outProtocol)
    {
        if (ctx.getChannel().isConnected()) {
            try {
                TMessage message = inProtocol.readMessageBegin();
                outProtocol.writeMessageBegin(new TMessage(message.name, TMessageType.EXCEPTION, message.seqid));
                x.write(outProtocol);
                outProtocol.writeMessageEnd();
                outProtocol.getTransport().flush();

                ThriftMessage response = request.getMessageFactory().create(requestTransport.getOutputBuffer());
                writeResponse(ctx, response, responseSequenceId, DispatcherContext.isResponseOrderingRequired(ctx));
            }
            catch (TException ex) {
                onDispatchException(ctx, ex);
            }
        }
    }

    private void onDispatchException(ChannelHandlerContext ctx, Throwable t)
    {
        Channels.fireExceptionCaught(ctx, t);
        closeChannel(ctx);
    }

    private void writeResponse(ChannelHandlerContext ctx,
                               ThriftMessage response,
                               int responseSequenceId,
                               boolean isOrderedResponsesRequired)
    {
        if (isOrderedResponsesRequired) {
            writeResponseInOrder(ctx, response, responseSequenceId);
        }
        else {
            // No ordering required, just write the response immediately
            Channels.write(ctx.getChannel(), response);
            lastResponseWrittenId.incrementAndGet();
        }
    }

    private void writeResponseInOrder(ChannelHandlerContext ctx,
                                      ThriftMessage response,
                                      int responseSequenceId)
    {
        // Ensure responses to requests are written in the same order the requests
        // were received.
        synchronized (responseMap) {
            int currentResponseId = lastResponseWrittenId.get() + 1;
            if (responseSequenceId != currentResponseId) {
                // This response is NOT next in line of ordered responses, save it to
                // be sent later, after responses to all earlier requests have been
                // sent.
                responseMap.put(responseSequenceId, response);
            } else {
                // This response was next in line, write this response now, and see if
                // there are others next in line that should be sent now as well.
                do {
                    Channels.write(ctx.getChannel(), response);
                    lastResponseWrittenId.incrementAndGet();
                    ++currentResponseId;
                    response = responseMap.remove(currentResponseId);
                } while (null != response);

                // Now that we've written some responses, check if reads should be unblocked
                if (DispatcherContext.isChannelReadBlocked(ctx)) {
                    int lastRequestSequenceId = dispatcherSequenceId.get();
                    if (lastRequestSequenceId <= lastResponseWrittenId.get() + queuedResponseLimit) {
                        DispatcherContext.unblockChannelReads(ctx);
                    }
                }
            }
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
            throws Exception
    {
        // Any out of band exception are caught here and we tear down the socket
        closeChannel(ctx);

        // Send for logging
        ctx.sendUpstream(e);
    }

    private void closeChannel(ChannelHandlerContext ctx)
    {
        if (ctx.getChannel().isOpen()) {
            ctx.getChannel().close();
        }
    }

    @Override
    public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        // Reads always start out unblocked
        DispatcherContext.unblockChannelReads(ctx);
        super.channelOpen(ctx, e);
    }

    private static class DispatcherContext
    {
        private ReadBlockedState readBlockedState = ReadBlockedState.NOT_BLOCKED;
        private boolean responseOrderingRequired = false;
        private boolean responseOrderingRequirementInitialized = false;

        public static boolean isChannelReadBlocked(ChannelHandlerContext ctx) {
            return getDispatcherContext(ctx).readBlockedState == ReadBlockedState.BLOCKED;
        }

        public static void blockChannelReads(ChannelHandlerContext ctx) {
            // Remember that reads are blocked (there is no Channel.getReadable())
            getDispatcherContext(ctx).readBlockedState = ReadBlockedState.BLOCKED;

            // NOTE: this shuts down reads, but isn't a 100% guarantee we won't get any more messages.
            // It sets up the channel so that the polling loop will not report any new read events
            // and netty won't read any more data from the socket, but any messages already fully read
            // from the socket before this ran may still be decoded and arrive at this handler. Thus
            // the limit on queued messages before we block reads is more of a guidance than a hard
            // limit.
            ctx.getChannel().setReadable(false);
        }

        public static void unblockChannelReads(ChannelHandlerContext ctx) {
            // Remember that reads are unblocked (there is no Channel.getReadable())
            getDispatcherContext(ctx).readBlockedState = ReadBlockedState.NOT_BLOCKED;
            ctx.getChannel().setReadable(true);
        }

        public static void setResponseOrderingRequired(ChannelHandlerContext ctx, boolean required)
        {
            DispatcherContext dispatcherContext = getDispatcherContext(ctx);
            dispatcherContext.responseOrderingRequirementInitialized = true;
            dispatcherContext.responseOrderingRequired = required;
        }

        public static boolean isResponseOrderingRequired(ChannelHandlerContext ctx)
        {
            return getDispatcherContext(ctx).responseOrderingRequired;
        }

        public static boolean isResponseOrderingRequirementInitialized(ChannelHandlerContext ctx)
        {
            return getDispatcherContext(ctx).responseOrderingRequirementInitialized;
        }

        private static DispatcherContext getDispatcherContext(ChannelHandlerContext ctx)
        {
            DispatcherContext dispatcherContext;
            Object attachment = ctx.getAttachment();

            if (attachment == null) {
                // No context was added yet, add one
                dispatcherContext = new DispatcherContext();
                ctx.setAttachment(dispatcherContext);
            }
            else if (!(attachment instanceof DispatcherContext)) {
                // There was a context, but it was the wrong type. This should never happen.
                throw new IllegalStateException("NiftyDispatcher handler context should be of type NiftyDispatcher.DispatcherContext");
            }
            else {
                dispatcherContext = (DispatcherContext) attachment;
            }

            return dispatcherContext;
        }

        private enum ReadBlockedState {
            NOT_BLOCKED,
            BLOCKED,
        }
    }
}
TOP

Related Classes of com.facebook.nifty.core.NiftyDispatcher$DispatcherContext

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.