Package io.netty.channel.local

Source Code of io.netty.channel.local.LocalTransportThreadModelTest$MessageForwarder3

/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.channel.local;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.DefaultExecutorServiceFactory;
import io.netty.util.concurrent.EventExecutorGroup;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicReference;

public class LocalTransportThreadModelTest {

    private static EventLoopGroup group;
    private static LocalAddress localAddr;

    @BeforeClass
    public static void init() {
        // Configure a test server
        group = new DefaultEventLoopGroup();
        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelHandlerAdapter() {
                      @Override
                      public void channelRead(ChannelHandlerContext ctx, Object msg) {
                          // Discard
                          ReferenceCountUtil.release(msg);
                      }
                  });
              }
          });

        localAddr = (LocalAddress) sb.bind(LocalAddress.ANY).syncUninterruptibly().channel().localAddress();
    }

    @AfterClass
    public static void destroy() throws Exception {
        group.shutdownGracefully().sync();
    }

    @Test(timeout = 30000)
    @Ignore("regression test")
    public void testStagedExecutionMultiple() throws Throwable {
        for (int i = 0; i < 10; i ++) {
            testStagedExecution();
        }
    }

    @Test(timeout = 5000)
    public void testStagedExecution() throws Throwable {
        EventLoopGroup l = new DefaultEventLoopGroup(4, new DefaultExecutorServiceFactory("l"));
        EventExecutorGroup e1 = new DefaultEventExecutorGroup(4, new DefaultExecutorServiceFactory("e1"));
        EventExecutorGroup e2 = new DefaultEventExecutorGroup(4, new DefaultExecutorServiceFactory("e2"));
        ThreadNameAuditor h1 = new ThreadNameAuditor();
        ThreadNameAuditor h2 = new ThreadNameAuditor();
        ThreadNameAuditor h3 = new ThreadNameAuditor(true);

        Channel ch = new LocalChannel();
        // With no EventExecutor specified, h1 will be always invoked by EventLoop 'l'.
        ch.pipeline().addLast(h1);
        // h2 will be always invoked by EventExecutor 'e1'.
        ch.pipeline().addLast(e1, h2);
        // h3 will be always invoked by EventExecutor 'e2'.
        ch.pipeline().addLast(e2, h3);

        l.register(ch).sync().channel().connect(localAddr).sync();

        // Fire inbound events from all possible starting points.
        ch.pipeline().fireChannelRead("1");
        ch.pipeline().context(h1).fireChannelRead("2");
        ch.pipeline().context(h2).fireChannelRead("3");
        ch.pipeline().context(h3).fireChannelRead("4");
        // Fire outbound events from all possible starting points.
        ch.pipeline().write("5");
        ch.pipeline().context(h3).write("6");
        ch.pipeline().context(h2).write("7");
        ch.pipeline().context(h1).writeAndFlush("8").sync();

        ch.close().sync();

        // Wait until all events are handled completely.
        while (h1.outboundThreadNames.size() < 3 || h3.inboundThreadNames.size() < 3 ||
               h1.removalThreadNames.size() < 1) {
            if (h1.exception.get() != null) {
                throw h1.exception.get();
            }
            if (h2.exception.get() != null) {
                throw h2.exception.get();
            }
            if (h3.exception.get() != null) {
                throw h3.exception.get();
            }

            Thread.sleep(10);
        }

        String currentName = Thread.currentThread().getName();

        try {
            // Events should never be handled from the current thread.
            Assert.assertFalse(h1.inboundThreadNames.contains(currentName));
            Assert.assertFalse(h2.inboundThreadNames.contains(currentName));
            Assert.assertFalse(h3.inboundThreadNames.contains(currentName));
            Assert.assertFalse(h1.outboundThreadNames.contains(currentName));
            Assert.assertFalse(h2.outboundThreadNames.contains(currentName));
            Assert.assertFalse(h3.outboundThreadNames.contains(currentName));
            Assert.assertFalse(h1.removalThreadNames.contains(currentName));
            Assert.assertFalse(h2.removalThreadNames.contains(currentName));
            Assert.assertFalse(h3.removalThreadNames.contains(currentName));

            // Assert that events were handled by the correct executor.
            for (String name: h1.inboundThreadNames) {
                Assert.assertTrue(name.startsWith("l-"));
            }
            for (String name: h2.inboundThreadNames) {
                Assert.assertTrue(name.startsWith("e1-"));
            }
            for (String name: h3.inboundThreadNames) {
                Assert.assertTrue(name.startsWith("e2-"));
            }
            for (String name: h1.outboundThreadNames) {
                Assert.assertTrue(name.startsWith("l-"));
            }
            for (String name: h2.outboundThreadNames) {
                Assert.assertTrue(name.startsWith("e1-"));
            }
            for (String name: h3.outboundThreadNames) {
                Assert.assertTrue(name.startsWith("e2-"));
            }
            for (String name: h1.removalThreadNames) {
                Assert.assertTrue(name.startsWith("l-"));
            }
            for (String name: h2.removalThreadNames) {
                Assert.assertTrue(name.startsWith("e1-"));
            }
            for (String name: h3.removalThreadNames) {
                Assert.assertTrue(name.startsWith("e2-"));
            }

            // Count the number of events
            Assert.assertEquals(1, h1.inboundThreadNames.size());
            Assert.assertEquals(2, h2.inboundThreadNames.size());
            Assert.assertEquals(3, h3.inboundThreadNames.size());
            Assert.assertEquals(3, h1.outboundThreadNames.size());
            Assert.assertEquals(2, h2.outboundThreadNames.size());
            Assert.assertEquals(1, h3.outboundThreadNames.size());
            Assert.assertEquals(1, h1.removalThreadNames.size());
            Assert.assertEquals(1, h2.removalThreadNames.size());
            Assert.assertEquals(1, h3.removalThreadNames.size());
        } catch (AssertionError e) {
            System.out.println("H1I: " + h1.inboundThreadNames);
            System.out.println("H2I: " + h2.inboundThreadNames);
            System.out.println("H3I: " + h3.inboundThreadNames);
            System.out.println("H1O: " + h1.outboundThreadNames);
            System.out.println("H2O: " + h2.outboundThreadNames);
            System.out.println("H3O: " + h3.outboundThreadNames);
            System.out.println("H1R: " + h1.removalThreadNames);
            System.out.println("H2R: " + h2.removalThreadNames);
            System.out.println("H3R: " + h3.removalThreadNames);
            throw e;
        } finally {
            l.shutdownGracefully();
            e1.shutdownGracefully();
            e2.shutdownGracefully();

            l.terminationFuture().sync();
            e1.terminationFuture().sync();
            e2.terminationFuture().sync();
        }
    }

    @Test(timeout = 30000)
    @Ignore
    public void testConcurrentMessageBufferAccess() throws Throwable {
        EventLoopGroup l0 = new DefaultEventLoopGroup(4, new DefaultExecutorServiceFactory("l0"));
        EventExecutorGroup e1 = new DefaultEventExecutorGroup(4, new DefaultExecutorServiceFactory("e1"));
        EventExecutorGroup e2 = new DefaultEventExecutorGroup(4, new DefaultExecutorServiceFactory("e2"));
        EventExecutorGroup e3 = new DefaultEventExecutorGroup(4, new DefaultExecutorServiceFactory("e3"));
        EventExecutorGroup e4 = new DefaultEventExecutorGroup(4, new DefaultExecutorServiceFactory("e4"));
        EventExecutorGroup e5 = new DefaultEventExecutorGroup(4, new DefaultExecutorServiceFactory("e5"));

        try {
            final MessageForwarder1 h1 = new MessageForwarder1();
            final MessageForwarder2 h2 = new MessageForwarder2();
            final MessageForwarder3 h3 = new MessageForwarder3();
            final MessageForwarder1 h4 = new MessageForwarder1();
            final MessageForwarder2 h5 = new MessageForwarder2();
            final MessageDiscarder  h6 = new MessageDiscarder();

            final Channel ch = new LocalChannel();

            // inbound:  int -> byte[4] -> int -> int -> byte[4] -> int -> /dev/null
            // outbound: int -> int -> byte[4] -> int -> int -> byte[4] -> /dev/null
            ch.pipeline().addLast(h1)
                         .addLast(e1, h2)
                         .addLast(e2, h3)
                         .addLast(e3, h4)
                         .addLast(e4, h5)
                         .addLast(e5, h6);

            l0.register(ch).sync().channel().connect(localAddr).sync();

            final int ROUNDS = 1024;
            final int ELEMS_PER_ROUNDS = 8192;
            final int TOTAL_CNT = ROUNDS * ELEMS_PER_ROUNDS;
            for (int i = 0; i < TOTAL_CNT;) {
                final int start = i;
                final int end = i + ELEMS_PER_ROUNDS;
                i = end;

                ch.eventLoop().execute(new Runnable() {
                    @Override
                    public void run() {
                        for (int j = start; j < end; j ++) {
                            ch.pipeline().fireChannelRead(Integer.valueOf(j));
                        }
                    }
                });
            }

            while (h1.inCnt < TOTAL_CNT || h2.inCnt < TOTAL_CNT || h3.inCnt < TOTAL_CNT ||
                    h4.inCnt < TOTAL_CNT || h5.inCnt < TOTAL_CNT || h6.inCnt < TOTAL_CNT) {
                if (h1.exception.get() != null) {
                    throw h1.exception.get();
                }
                if (h2.exception.get() != null) {
                    throw h2.exception.get();
                }
                if (h3.exception.get() != null) {
                    throw h3.exception.get();
                }
                if (h4.exception.get() != null) {
                    throw h4.exception.get();
                }
                if (h5.exception.get() != null) {
                    throw h5.exception.get();
                }
                if (h6.exception.get() != null) {
                    throw h6.exception.get();
                }
                Thread.sleep(10);
            }

            for (int i = 0; i < TOTAL_CNT;) {
                final int start = i;
                final int end = i + ELEMS_PER_ROUNDS;
                i = end;

                ch.pipeline().context(h6).executor().execute(new Runnable() {
                    @Override
                    public void run() {
                        for (int j = start; j < end; j ++) {
                            ch.write(Integer.valueOf(j));
                        }
                        ch.flush();
                    }
                });
            }

            while (h1.outCnt < TOTAL_CNT || h2.outCnt < TOTAL_CNT || h3.outCnt < TOTAL_CNT ||
                    h4.outCnt < TOTAL_CNT || h5.outCnt < TOTAL_CNT || h6.outCnt < TOTAL_CNT) {
                if (h1.exception.get() != null) {
                    throw h1.exception.get();
                }
                if (h2.exception.get() != null) {
                    throw h2.exception.get();
                }
                if (h3.exception.get() != null) {
                    throw h3.exception.get();
                }
                if (h4.exception.get() != null) {
                    throw h4.exception.get();
                }
                if (h5.exception.get() != null) {
                    throw h5.exception.get();
                }
                if (h6.exception.get() != null) {
                    throw h6.exception.get();
                }
                Thread.sleep(10);
            }

            ch.close().sync();
        } finally {
            l0.shutdownGracefully();
            e1.shutdownGracefully();
            e2.shutdownGracefully();
            e3.shutdownGracefully();
            e4.shutdownGracefully();
            e5.shutdownGracefully();

            l0.terminationFuture().sync();
            e1.terminationFuture().sync();
            e2.terminationFuture().sync();
            e3.terminationFuture().sync();
            e4.terminationFuture().sync();
            e5.terminationFuture().sync();
        }
    }

    private static class ThreadNameAuditor extends ChannelHandlerAdapter {

        private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();

        private final Queue<String> inboundThreadNames = new ConcurrentLinkedQueue<String>();
        private final Queue<String> outboundThreadNames = new ConcurrentLinkedQueue<String>();
        private final Queue<String> removalThreadNames = new ConcurrentLinkedQueue<String>();
        private final boolean discard;

        ThreadNameAuditor() {
            this(false);
        }

        ThreadNameAuditor(boolean discard) {
            this.discard = discard;
        }

        @Override
        public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
            removalThreadNames.add(Thread.currentThread().getName());
        }

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            inboundThreadNames.add(Thread.currentThread().getName());
            if (!discard) {
                ctx.fireChannelRead(msg);
            }
        }

        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            outboundThreadNames.add(Thread.currentThread().getName());
            ctx.write(msg, promise);
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            exception.compareAndSet(null, cause);
            System.err.print('[' + Thread.currentThread().getName() + "] ");
            cause.printStackTrace();
            super.exceptionCaught(ctx, cause);
        }
    }

    /**
     * Converts integers into a binary stream.
     */
    private static class MessageForwarder1 extends ChannelHandlerAdapter {

        private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
        private volatile int inCnt;
        private volatile int outCnt;
        private volatile Thread t;

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            Thread t = this.t;
            if (t == null) {
                this.t = Thread.currentThread();
            } else {
                assertSameExecutor(t, Thread.currentThread());
            }

            ByteBuf out = ctx.alloc().buffer(4);
            int m = ((Integer) msg).intValue();
            int expected = inCnt ++;
            Assert.assertEquals(expected, m);
            out.writeInt(m);

            ctx.fireChannelRead(out);
        }

        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            assertSameExecutor(t, Thread.currentThread());

            // Don't let the write request go to the server-side channel - just swallow.
            boolean swallow = this == ctx.pipeline().first();

            ByteBuf m = (ByteBuf) msg;
            int count = m.readableBytes() / 4;
            for (int j = 0; j < count; j ++) {
                int actual = m.readInt();
                int expected = outCnt ++;
                Assert.assertEquals(expected, actual);
                if (!swallow) {
                    ctx.write(actual);
                }
            }
            ctx.writeAndFlush(Unpooled.EMPTY_BUFFER, promise);
            m.release();
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            exception.compareAndSet(null, cause);
            //System.err.print("[" + Thread.currentThread().getName() + "] ");
            //cause.printStackTrace();
            super.exceptionCaught(ctx, cause);
        }
    }

    /**
     * Converts a binary stream into integers.
     */
    private static class MessageForwarder2 extends ChannelHandlerAdapter {

        private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
        private volatile int inCnt;
        private volatile int outCnt;
        private volatile Thread t;

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            Thread t = this.t;
            if (t == null) {
                this.t = Thread.currentThread();
            } else {
                assertSameExecutor(t, Thread.currentThread());
            }

            ByteBuf m = (ByteBuf) msg;
            int count = m.readableBytes() / 4;
            for (int j = 0; j < count; j ++) {
                int actual = m.readInt();
                int expected = inCnt ++;
                Assert.assertEquals(expected, actual);
                ctx.fireChannelRead(actual);
            }
            m.release();
        }

        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            assertSameExecutor(t, Thread.currentThread());

            ByteBuf out = ctx.alloc().buffer(4);
            int m = (Integer) msg;
            int expected = outCnt ++;
            Assert.assertEquals(expected, m);
            out.writeInt(m);

            ctx.write(out, promise);
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            exception.compareAndSet(null, cause);
            //System.err.print("[" + Thread.currentThread().getName() + "] ");
            //cause.printStackTrace();
            super.exceptionCaught(ctx, cause);
        }
    }

    /**
     * Simply forwards the received object to the next handler.
     */
    private static class MessageForwarder3 extends ChannelHandlerAdapter {

        private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
        private volatile int inCnt;
        private volatile int outCnt;
        private volatile Thread t;

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            Thread t = this.t;
            if (t == null) {
                this.t = Thread.currentThread();
            } else {
                assertSameExecutor(t, Thread.currentThread());
            }

            int actual = (Integer) msg;
            int expected = inCnt ++;
            Assert.assertEquals(expected, actual);

            ctx.fireChannelRead(msg);
        }

        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            assertSameExecutor(t, Thread.currentThread());

            int actual = (Integer) msg;
            int expected = outCnt ++;
            Assert.assertEquals(expected, actual);

            ctx.write(msg, promise);
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            exception.compareAndSet(null, cause);
            System.err.print('[' + Thread.currentThread().getName() + "] ");
            cause.printStackTrace();
            super.exceptionCaught(ctx, cause);
        }
    }

    /**
     * Discards all received messages.
     */
    private static class MessageDiscarder extends ChannelHandlerAdapter {

        private final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
        private volatile int inCnt;
        private volatile int outCnt;
        private volatile Thread t;

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            Thread t = this.t;
            if (t == null) {
                this.t = Thread.currentThread();
            } else {
                assertSameExecutor(t, Thread.currentThread());
            }

            int actual = (Integer) msg;
            int expected = inCnt ++;
            Assert.assertEquals(expected, actual);
        }

        @Override
        public void write(
                ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
            assertSameExecutor(t, Thread.currentThread());

            int actual = (Integer) msg;
            int expected = outCnt ++;
            Assert.assertEquals(expected, actual);
            ctx.write(msg, promise);
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            exception.compareAndSet(null, cause);
            //System.err.print("[" + Thread.currentThread().getName() + "] ");
            //cause.printStackTrace();
            super.exceptionCaught(ctx, cause);
        }
    }

    private static void assertSameExecutor(Thread expected, Thread actual) {
        Assert.assertEquals(expected.getName().substring(0, 2), actual.getName().substring(0, 2));
    }
}
TOP

Related Classes of io.netty.channel.local.LocalTransportThreadModelTest$MessageForwarder3

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.