Package org.apache.mina.filter.traffic

Source Code of org.apache.mina.filter.traffic.ReadThrottleFilter

/*
*  Licensed to the Apache Software Foundation (ASF) under one
*  or more contributor license agreements.  See the NOTICE file
*  distributed with this work for additional information
*  regarding copyright ownership.  The ASF licenses this file
*  to you under the Apache License, Version 2.0 (the
*  "License"); you may not use this file except in compliance
*  with the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
*  Unless required by applicable law or agreed to in writing,
*  software distributed under the License is distributed on an
*  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*  KIND, either express or implied.  See the License for the
*  specific language governing permissions and limitations
*  under the License.
*
*/
package org.apache.mina.filter.traffic;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.mina.common.AttributeKey;
import org.apache.mina.common.IoBuffer;
import org.apache.mina.common.IoFilter;
import org.apache.mina.common.IoFilterAdapter;
import org.apache.mina.common.IoFilterChain;
import org.apache.mina.common.IoService;
import org.apache.mina.common.IoSession;
import org.apache.mina.common.TrafficMask;
import org.apache.mina.filter.executor.ExecutorFilter;
import org.apache.mina.util.CopyOnWriteMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* An {@link IoFilter} that throttles incoming traffic to
* prevent a unwanted {@link OutOfMemoryError} under heavy load.
* <p>
* This filter will automatically disable reads on an {@link IoSession} once
* the amount of the read data batched for that session in the {@link ExecutorFilter}
* reaches a defined threshold. It accomplishes this by adding one filter before the
* {@link ExecutorFilter}.
* <p>
* The size of the received data is calculated by {@link MessageSizeEstimator}.
* If you are using a transport whose envelope is not an {@link IoBuffer},
* you could write your own {@link MessageSizeEstimator} for better traffic
* calculation.  However, the {@link DefaultMessageSizeEstimator} will suffice
* in most cases.
* <p>
* It is recommended to add this filter at the end of your filter chain
* configuration because it is possible to subvert the behavior of the added
* filters by adding a filter immediately before/after the {@link ExecutorFilter}
* after inserting this builder, consequently leading to a unexpected behavior.
*
* @author The Apache MINA Project (dev@mina.apache.org)
* @version $Rev: 616100 $, $Date: 2008-01-28 15:58:32 -0700 (Mon, 28 Jan 2008) $
*/
public class ReadThrottleFilter extends IoFilterAdapter {
   
    private static final AtomicInteger globalBufferSize = new AtomicInteger();
    private static final Map<IoService, AtomicInteger> serviceBufferSizes =
        new CopyOnWriteMap<IoService, AtomicInteger>();
   
    private static final Object globalResumeLock = new Object();
    private static long lastGlobalResumeTime = 0;
    private final Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Returns the current amount of data in the buffer of the {@link ExecutorFilter}
     * for all {@link IoSession} whose {@link IoFilterChain} has been configured by
     * this builder.
     */
    public static int getGlobalBufferSize() {
        return globalBufferSize.get();
    }
   
    public static int getServiceBufferSize(IoService service) {
        AtomicInteger answer = serviceBufferSizes.get(service);
        if (answer == null) {
            return 0;
        } else {
            return answer.get();
        }
    }
   
    private static int increaseServiceBufferSize(IoService service, int increment) {
        AtomicInteger serviceBufferSize = serviceBufferSizes.get(service);
        if (serviceBufferSize == null) {
            synchronized (serviceBufferSizes) {
                serviceBufferSize = serviceBufferSizes.get(service);
                if (serviceBufferSize == null) {
                    serviceBufferSize = new AtomicInteger(increment);
                    serviceBufferSizes.put(service, serviceBufferSize);
                    return increment;
                }
            }
        }
        return serviceBufferSize.addAndGet(increment);
    }
   
    private final AttributeKey STATE =
        new AttributeKey(ReadThrottleFilter.class, "state");

    private volatile ReadThrottlePolicy policy;
    private final MessageSizeEstimator messageSizeEstimator;
   
    private volatile int maxSessionBufferSize;
    private volatile int maxServiceBufferSize;
    private volatile int maxGlobalBufferSize;
   
    private final IoFilter enterFilter = new EnterFilter();
   
    private final ScheduledExecutorService executor;
    private ScheduledFuture<?> resumeOthersFuture;
    private final AtomicInteger sessionCount = new AtomicInteger();
    private final Runnable resumeOthersTask = new Runnable() {
        public void run() {
            resumeOthers();
        }
    };

    /**
     * Creates a new instance with 64KB <tt>maxSessionBufferSize</tt>,
     * 128MB <tt>maxGlobalBufferSize</tt> and a new {@link DefaultMessageSizeEstimator}.
     */
    public ReadThrottleFilter(ScheduledExecutorService executor) {
        this(executor, ReadThrottlePolicy.LOG);
    }
   
    public ReadThrottleFilter(
            ScheduledExecutorService executor, ReadThrottlePolicy policy) {
        this(executor, policy, null);
    }
   
    public ReadThrottleFilter(
            ScheduledExecutorService executor,
            ReadThrottlePolicy policy, MessageSizeEstimator messageSizeEstimator) {
        // 64KB, 64MB, 128MB.
        this(executor, policy, messageSizeEstimator, 65536, 1048576 * 64, 1048576 * 128);
    }
   
    /**
     * Creates a new instance with the specified <tt>maxSessionBufferSize</tt>,
     * <tt>maxGlobalBufferSize</tt> and a new {@link DefaultMessageSizeEstimator}.
     */
    public ReadThrottleFilter(
            ScheduledExecutorService executor,
            int maxSessionBufferSize, int maxServiceBufferSize, int maxGlobalBufferSize) {
        this(executor, ReadThrottlePolicy.LOG, maxSessionBufferSize, maxServiceBufferSize, maxGlobalBufferSize);
    }

    public ReadThrottleFilter(
            ScheduledExecutorService executor, ReadThrottlePolicy policy,
            int maxSessionBufferSize, int maxServiceBufferSize, int maxGlobalBufferSize) {
        this(executor, policy, null, maxSessionBufferSize, maxServiceBufferSize, maxGlobalBufferSize);
    }

    /**
     * Creates a new instance with the specified <tt>maxSessionBufferSize</tt>,
     * <tt>maxGlobalBufferSize</tt> and {@link MessageSizeEstimator}.
     *
     * @param maxSessionBufferSize the maximum amount of data in the buffer of
     *                           the {@link ExecutorFilter} per {@link IoSession}.
     *                           Specify {@code 0} or a smaller value to disable.
     * @param maxGlobalBufferSize the maximum amount of data in the buffer of
     *                            the {@link ExecutorFilter} for all {@link IoSession}
     *                            whose {@link IoFilterChain} has been configured by
     *                            this builder.
     *                            Specify {@code 0} or a smaller value to disable.
     * @param messageSizeEstimator the message size estimator. If {@code null},
     *                             a new {@link DefaultMessageSizeEstimator} is created.
     */
    public ReadThrottleFilter(
            ScheduledExecutorService executor,
            ReadThrottlePolicy policy, MessageSizeEstimator messageSizeEstimator,
            int maxSessionBufferSize, int maxServiceBufferSize, int maxGlobalBufferSize) {
        if (messageSizeEstimator == null) {
            messageSizeEstimator = new DefaultMessageSizeEstimator();
        }
        this.executor = executor;
        this.messageSizeEstimator = messageSizeEstimator;
        setPolicy(policy);
        setMaxSessionBufferSize(maxSessionBufferSize);
        setMaxServiceBufferSize(maxServiceBufferSize);
        setMaxGlobalBufferSize(maxGlobalBufferSize);
    }

    public ReadThrottlePolicy getPolicy() {
        return policy;
    }

    public void setPolicy(ReadThrottlePolicy policy) {
        if (policy == null) {
            throw new NullPointerException("policy");
        }
       
        this.policy = policy;
    }

    /**
     * Returns the maximum amount of data in the buffer of the {@link ExecutorFilter}
     * per {@link IoSession}.  {@code 0} means 'disabled'.
     */
    public int getMaxSessionBufferSize() {
        return maxSessionBufferSize;
    }
   
    public int getMaxServiceBufferSize() {
        return maxServiceBufferSize;
    }
   
    /**
     * Returns the maximum amount of data in the buffer of the {@link ExecutorFilter}
     * for all {@link IoSession} whose {@link IoFilterChain} has been configured by
     * this builder. {@code 0} means 'disabled'.
     */
    public int getMaxGlobalBufferSize() {
        return maxGlobalBufferSize;
    }
   
    /**
     * Sets the maximum amount of data in the buffer of the {@link ExecutorFilter}
     * per {@link IoSession}.  Specify {@code 0} or a smaller value to disable.
     */
    public void setMaxSessionBufferSize(int maxSessionBufferSize) {
        if (maxSessionBufferSize < 0) {
            maxSessionBufferSize = 0;
        }
        this.maxSessionBufferSize = maxSessionBufferSize;
    }

    public void setMaxServiceBufferSize(int maxServiceBufferSize) {
        if (maxServiceBufferSize < 0) {
            maxServiceBufferSize = 0;
        }
        this.maxServiceBufferSize = maxServiceBufferSize;
    }

    /**
     * Sets the maximum amount of data in the buffer of the {@link ExecutorFilter}
     * for all {@link IoSession} whose {@link IoFilterChain} has been configured by
     * this builder. Specify {@code 0} or a smaller value to disable.
     */
    public void setMaxGlobalBufferSize(int maxGlobalBufferSize) {
        if (maxGlobalBufferSize < 0) {
            maxGlobalBufferSize = 0;
        }
        this.maxGlobalBufferSize = maxGlobalBufferSize;
    }
   
    /**
     * Returns the size estimator currently in use.
     */
    public MessageSizeEstimator getMessageSizeEstimator() {
        return messageSizeEstimator;
    }
   
    /**
     * Returns the current amount of data in the buffer of the {@link ExecutorFilter}
     * for the specified {@link IoSession}.
     */
    public int getSessionBufferSize(IoSession session) {
        State state = (State) session.getAttribute(STATE);
        if (state == null) {
            return 0;
        }
       
        synchronized (state) {
            return state.sessionBufferSize;
        }
    }
   
    @Override
    public void onPreAdd(
            IoFilterChain parent, String name, NextFilter nextFilter) throws Exception {
        if (!parent.contains(ExecutorFilter.class)) {
            throw new IllegalStateException(
                    "At least one " + ExecutorFilter.class.getName() + " must exist in the chain.");
        }
        if (parent.contains(this)) {
            throw new IllegalArgumentException(
                    "You can't add the same filter instance more than once.  Create another instance and add it.");
        }
    }
   
    @Override
    public void onPostAdd(
            IoFilterChain parent, String name, NextFilter nextFilter) throws Exception {
       
        // My previous filter must be an ExecutorFilter.
        IoFilter lastFilter = null;
        for (IoFilterChain.Entry e: parent.getAll()) {
            IoFilter currentFilter = e.getFilter();
            if (currentFilter == this) {
                if (lastFilter instanceof ExecutorFilter) {
                    // Good!
                    break;
                } else {
                    throw new IllegalStateException(
                            ReadThrottleFilter.class.getName() + " must be placed after " +
                            "an " + ExecutorFilter.class.getName() + " in the chain");
                }
            }
           
            lastFilter = currentFilter;
        }
       
        // Add an entering filter before the ExecutorFilter.
        parent.getEntry(lastFilter).addBefore(name + ".preprocessor", enterFilter);
       
        int previousSessionCount = sessionCount.getAndIncrement();
        if (previousSessionCount == 0) {
            synchronized (resumeOthersTask) {
                resumeOthersFuture = executor.scheduleWithFixedDelay(
                        resumeOthersTask, 3000, 3000, TimeUnit.MILLISECONDS);
            }
        }
    }

    @Override
    public void onPostRemove(
            IoFilterChain parent, String name, NextFilter nextFilter) throws Exception {
        // Remove the enter filter together.
        try {
            parent.remove(enterFilter);
        } catch (Exception e) {
            // Ignore.
        }
       
        int currentSessionCount = sessionCount.decrementAndGet();
        if (currentSessionCount == 0) {
            synchronized (resumeOthersTask) {
                resumeOthersFuture.cancel(false);
                resumeOthersFuture = null;
            }
        }
    }

    @Override
    public void messageReceived(
            NextFilter nextFilter, IoSession session, Object message) throws Exception {
        exit(session, estimateSize(message));
        nextFilter.messageReceived(session, message);
    }

    @Override
    public void filterSetTrafficMask(NextFilter nextFilter, IoSession session,
            TrafficMask trafficMask) throws Exception {
       
        if (trafficMask.isReadable()) {
            State state = getState(session);
            boolean suspendedRead;
            synchronized (state) {
                suspendedRead = state.suspendedRead;
            }
           
            // Suppress resumeRead() if read is suspended by this filter.
            if (suspendedRead) {
                trafficMask = trafficMask.and(TrafficMask.WRITE);
            }
        }
       
        nextFilter.filterSetTrafficMask(session, trafficMask);
    }

    private class EnterFilter extends IoFilterAdapter {
        @Override
        public void onPreRemove(
                IoFilterChain parent, String name, NextFilter nextFilter) throws Exception {
            // Remove the exit filter together.
            try {
                parent.remove(ReadThrottleFilter.this);
            } catch (Exception e) {
                // Ignore.
            }
        }
       
        @Override
        public void onPostRemove(
                IoFilterChain parent, String name, NextFilter nextFilter) throws Exception {
            parent.getSession().removeAttribute(STATE);
        }
   
        @Override
        public void messageReceived(
                NextFilter nextFilter, IoSession session, Object message) throws Exception {
            enter(session, estimateSize(message));
            nextFilter.messageReceived(session, message);
        }
    }
   
    private int estimateSize(Object message) {
        int size = messageSizeEstimator.estimateSize(message);
        if (size < 0) {
            throw new IllegalStateException(
                    MessageSizeEstimator.class.getSimpleName() + " returned " +
                    "a negative value (" + size + "): " + message);
        }
        return size;
    }

    private void enter(IoSession session, int size) {
        State state = getState(session);

        int globalBufferSize = ReadThrottleFilter.globalBufferSize.addAndGet(size);
        int serviceBufferSize = increaseServiceBufferSize(session.getService(), size);

        int maxGlobalBufferSize = this.maxGlobalBufferSize;
        int maxServiceBufferSize = this.maxServiceBufferSize;
        int maxSessionBufferSize = this.maxSessionBufferSize;
       
        ReadThrottlePolicy policy = getPolicy();
       
        boolean enforcePolicy = false;
        int sessionBufferSize;
        synchronized (state) {
            sessionBufferSize = (state.sessionBufferSize += size);
            if ((maxSessionBufferSize != 0 && sessionBufferSize >= maxSessionBufferSize) ||
                (maxServiceBufferSize != 0 && serviceBufferSize >= maxServiceBufferSize) ||
                (maxGlobalBufferSize  != 0 && globalBufferSize  >= maxGlobalBufferSize)) {
                enforcePolicy = true;
                switch (policy) {
                case EXCEPTION:
                case BLOCK:
                    state.suspendedRead = true;
                }
            }
        }

        if (logger.isDebugEnabled()) {
            logger.debug(getMessage(session, "  Entered - "));
        }
       
        if (enforcePolicy) {
            switch (policy) {
            case CLOSE:
                log(session, state);
                session.close();
                raiseException(session);
                break;
            case EXCEPTION:
                suspend(session, state, logger);
                raiseException(session);
                break;
            case BLOCK:
                suspend(session, state, logger);
                break;
            case LOG:
                log(session, state);
                break;
            }
        }
    }

    private void suspend(IoSession session, State state, Logger logger) {
        log(session, state);
        session.suspendRead();
        if (logger.isDebugEnabled()) {
            logger.debug(getMessage(session, "Suspended - "));
        }
    }
   
    private void exit(IoSession session, int size) {
        State state = getState(session);

        int globalBufferSize = ReadThrottleFilter.globalBufferSize.addAndGet(-size);
        if (globalBufferSize < 0) {
            throw new IllegalStateException("globalBufferSize: " + globalBufferSize);
        }
       
        int serviceBufferSize = increaseServiceBufferSize(session.getService(), -size);
        if (serviceBufferSize < 0) {
            throw new IllegalStateException("serviceBufferSize: " + serviceBufferSize);
        }

        int maxGlobalBufferSize = this.maxGlobalBufferSize;
        int maxServiceBufferSize = this.maxServiceBufferSize;
        int maxSessionBufferSize = this.maxSessionBufferSize;
       
        int sessionBufferSize;
       
        boolean enforcePolicy = false;
        synchronized (state) {
            sessionBufferSize = (state.sessionBufferSize -= size);
            if (sessionBufferSize < 0) {
                throw new IllegalStateException("sessionBufferSize: " + sessionBufferSize);
            }
            if ((maxGlobalBufferSize == 0 || globalBufferSize < maxGlobalBufferSize) &&
                (maxServiceBufferSize == 0 || serviceBufferSize < maxServiceBufferSize) &&
                (maxSessionBufferSize == 0 || sessionBufferSize < maxSessionBufferSize)) {
                state.suspendedRead = false;
                enforcePolicy = true;
            }
        }
       
        if (logger.isDebugEnabled()) {
            logger.debug(getMessage(session, "   Exited - "));
        }
       
        if (enforcePolicy) {
            session.resumeRead();
            if (logger.isDebugEnabled()) {
                logger.debug(getMessage(session, "  Resumed - "));
            }
        }
       
        resumeOthers();
    }
   
    private void resumeOthers() {
        long currentTime = System.currentTimeMillis();
       
        // Try to resume other sessions every other second.
        boolean resumeOthers;
        synchronized (globalResumeLock) {
            if (currentTime - lastGlobalResumeTime > 1000) {
                lastGlobalResumeTime = currentTime;
                resumeOthers = true;
            } else {
                resumeOthers = false;
            }
        }
       
        if (resumeOthers) {
            int maxGlobalBufferSize = this.maxGlobalBufferSize;
            if (maxGlobalBufferSize == 0 || globalBufferSize.get() < maxGlobalBufferSize) {
                List<IoService> inactiveServices = null;
                for (IoService service: serviceBufferSizes.keySet()) {
                    resumeService(service);
                   
                    if (!service.isActive()) {
                        if (inactiveServices == null) {
                            inactiveServices = new ArrayList<IoService>();
                        }
                        inactiveServices.add(service);
                    }
                   
                    // Remove inactive services from the map.
                    if (inactiveServices != null) {
                        for (IoService s: inactiveServices) {
                            serviceBufferSizes.remove(s);
                        }
                    }

                    synchronized (globalResumeLock) {
                        lastGlobalResumeTime = System.currentTimeMillis();
                    }
                }
            }
        }
    }
   
    private void resumeService(IoService service) {
        int maxServiceBufferSize = this.maxServiceBufferSize;
        if (maxServiceBufferSize == 0 || getServiceBufferSize(service) < maxServiceBufferSize) {
            for (IoSession session: service.getManagedSessions()) {
                resume(session);
            }
        }
    }
   
    private void resume(IoSession session) {
        State state = (State) session.getAttribute(STATE);
        if (state == null) {
            return;
        }
       
        int maxSessionBufferSize = this.maxSessionBufferSize;
        boolean resume = false;
        synchronized (state) {
            if ((maxSessionBufferSize == 0 || state.sessionBufferSize < maxSessionBufferSize)) {
                state.suspendedRead = false;
                resume = true;
            }
        }

        if (resume) {
            session.resumeRead();
            if (logger.isDebugEnabled()) {
                logger.debug(getMessage(session, "  Resumed - "));
            }
        }
    }

    private void log(IoSession session, State state) {
        long currentTime = System.currentTimeMillis();
       
        // Prevent log flood by logging every 3 seconds.
        boolean log;
        synchronized (state.logLock) {
            if (currentTime - state.lastLogTime > 3000) {
                state.lastLogTime = currentTime;
                log = true;
            } else {
                log = false;
            }
        }
       
        if (log) {
            logger.warn(getMessage(session));
        }
    }
   
    private void raiseException(IoSession session) {
        throw new ReadFloodException(getMessage(session));
    }
   
    private String getMessage(IoSession session) {
        return getMessage(session, "Read buffer flooded - ");
    }
   
    private String getMessage(IoSession session, String prefix) {
        int  sessionLimit = maxSessionBufferSize;
        int  serviceLimit = maxServiceBufferSize;
        int  globalLimit  = maxGlobalBufferSize;

        StringBuilder buf = new StringBuilder(512);
        buf.append(prefix);
        buf.append("session: ");
        if (sessionLimit != 0) {
            buf.append(getSessionBufferSize(session));
            buf.append(" / ");
            buf.append(sessionLimit);
            buf.append(" bytes, ");
        } else {
            buf.append(getSessionBufferSize(session));
            buf.append(" / unlimited bytes, ");
        }
       
        buf.append("service: ");
        if (serviceLimit != 0) {
            buf.append(getServiceBufferSize(session.getService()));
            buf.append(" / ");
            buf.append(serviceLimit);
            buf.append(" bytes, ");
        } else {
            buf.append(getServiceBufferSize(session.getService()));
            buf.append(" / unlimited bytes, ");
        }
       
        buf.append("global: ");
        if (globalLimit != 0) {
            buf.append(getGlobalBufferSize());
            buf.append(" / ");
            buf.append(globalLimit);
            buf.append(" bytes.");
        } else {
            buf.append(getGlobalBufferSize());
            buf.append(" / unlimited bytes.");
        }
       
        return buf.toString();
    }
   
    private State getState(IoSession session) {
        State state = (State) session.getAttribute(STATE);
        if (state == null) {
            state = new State();
            State oldState = (State) session.setAttributeIfAbsent(STATE, state);
            if (oldState != null) {
                state = oldState;
            }
        }
        return state;
    }
   
    @Override
    public String toString() {
        return String.valueOf(getGlobalBufferSize()) + '/' + getMaxGlobalBufferSize();
    }

    private static class State {
        private int sessionBufferSize;
        private boolean suspendedRead;

        private final Object logLock = new Object();
        private long lastLogTime = 0;
    }
}
TOP

Related Classes of org.apache.mina.filter.traffic.ReadThrottleFilter

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.