Package org.apache.hama.ipc

Source Code of org.apache.hama.ipc.AsyncServer$Call

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hama.ipc;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.util.ReferenceCountUtil;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.StringUtils;

/**
* An abstract IPC service using netty. IPC calls take a single {@link Writable}
* as a parameter, and return a {@link Writable}*
*
* @see AsyncServer
*/
public abstract class AsyncServer {

  private AuthMethod authMethod;
  static final ByteBuffer HEADER = ByteBuffer.wrap("hrpc".getBytes());
  static int INITIAL_RESP_BUF_SIZE = 1024;
  UserGroupInformation user = null;
  // 1 : Introduce ping and server does not throw away RPCs
  // 3 : Introduce the protocol into the RPC connection header
  // 4 : Introduced SASL security layer
  static final byte CURRENT_VERSION = 4;
  static final int HEADER_LENGTH = 10;
  // follows version is read
  private Configuration conf;
  private final boolean tcpNoDelay; // if T then disable Nagle's Algorithm
  private int backlogLength;;
  InetSocketAddress address;
  private static final Log LOG = LogFactory.getLog(AsyncServer.class);
  private static int NIO_BUFFER_LIMIT = 8 * 1024;
  private final int maxRespSize;
  static final String IPC_SERVER_RPC_MAX_RESPONSE_SIZE_KEY = "ipc.server.max.response.size";
  static final int IPC_SERVER_RPC_MAX_RESPONSE_SIZE_DEFAULT = 1024 * 1024;

  private static final ThreadLocal<AsyncServer> SERVER = new ThreadLocal<AsyncServer>();
  private int port; // port we listen on
  private Class<? extends Writable> paramClass; // class of call parameters
  // Configure the server.(constructor is thread num)
  private EventLoopGroup bossGroup = new NioEventLoopGroup(1);
  private EventLoopGroup workerGroup = new NioEventLoopGroup();
  private static final Map<String, Class<?>> PROTOCOL_CACHE = new ConcurrentHashMap<String, Class<?>>();
  private ExceptionsHandler exceptionsHandler = new ExceptionsHandler();

  static Class<?> getProtocolClass(String protocolName, Configuration conf)
      throws ClassNotFoundException {
    Class<?> protocol = PROTOCOL_CACHE.get(protocolName);
    if (protocol == null) {
      protocol = conf.getClassByName(protocolName);
      PROTOCOL_CACHE.put(protocolName, protocol);
    }
    return protocol;
  }

  /**
   * Getting address
   *
   * @return InetSocketAddress
   */
  public InetSocketAddress getAddress() {
    return address;
  }

  /**
   * Returns the server instance called under or null. May be called under
   * {@link #call(Writable, long)} implementations, and under {@link Writable}
   * methods of paramters and return values. Permits applications to access the
   * server context.
   *
   * @return NioServer
   */
  public static AsyncServer get() {
    return SERVER.get();
  }

  /**
   * Constructs a server listening on the named port and address. Parameters
   * passed must be of the named class. The
   * <code>handlerCount</handlerCount> determines
   * the number of handler threads that will be used to process calls.
   *
   * @param bindAddress
   * @param port
   * @param paramClass
   * @param handlerCount
   * @param conf
   * @throws IOException
   */
  protected AsyncServer(String bindAddress, int port,
      Class<? extends Writable> paramClass, int handlerCount, Configuration conf)
      throws IOException {
    this(bindAddress, port, paramClass, handlerCount, conf, Integer
        .toString(port), null);
  }

  protected AsyncServer(String bindAddress, int port,
      Class<? extends Writable> paramClass, int handlerCount,
      Configuration conf, String serverName) throws IOException {
    this(bindAddress, port, paramClass, handlerCount, conf, serverName, null);
  }

  protected AsyncServer(String bindAddress, int port,
      Class<? extends Writable> paramClass, int handlerCount,
      Configuration conf, String serverName,
      SecretManager<? extends TokenIdentifier> secretManager)
      throws IOException {
    this.conf = conf;
    this.port = port;
    this.address = new InetSocketAddress(bindAddress, port);
    this.paramClass = paramClass;
    this.maxRespSize = conf.getInt(IPC_SERVER_RPC_MAX_RESPONSE_SIZE_KEY,
        IPC_SERVER_RPC_MAX_RESPONSE_SIZE_DEFAULT);

    this.tcpNoDelay = conf.getBoolean("ipc.server.tcpnodelay", true);
    this.backlogLength = conf.getInt("ipc.server.listen.queue.size", 100);
  }

  /** start server listener */
  public void start() {
    new NioServerListener().start();
  }

  private class NioServerListener extends Thread {

    /**
     * Configure and start nio server
     */
    @Override
    public void run() {
      SERVER.set(AsyncServer.this);
      try {
        // ServerBootstrap is a helper class that sets up a server
        ServerBootstrap b = new ServerBootstrap();
        b.group(bossGroup, workerGroup)
            .channel(NioServerSocketChannel.class)
            .option(ChannelOption.SO_BACKLOG, backlogLength)
            .childOption(ChannelOption.MAX_MESSAGES_PER_READ, NIO_BUFFER_LIMIT)
            .childOption(ChannelOption.TCP_NODELAY, tcpNoDelay)
            .childOption(ChannelOption.SO_KEEPALIVE, true)
            .childOption(ChannelOption.SO_RCVBUF, 30 * 1024 * 1024)
            .childOption(ChannelOption.RCVBUF_ALLOCATOR,
                new FixedRecvByteBufAllocator(100 * 1024))

            .childHandler(new ChannelInitializer<SocketChannel>() {
              @Override
              public void initChannel(SocketChannel ch) throws Exception {
                ChannelPipeline p = ch.pipeline();
                // Register accumulation processing handler
                p.addLast(new NioFrameDecoder(100 * 1024 * 1024, 0, 4, 0, 0));
                // Register message processing handler
                p.addLast(new NioServerInboundHandler());
              }
            });

        // Bind and start to accept incoming connections.
        ChannelFuture f = b.bind(port).sync();
        LOG.info("AsyncServer startup");
        // Wait until the server socket is closed.
        f.channel().closeFuture().sync();
      } catch (Exception e) {
        e.printStackTrace();
      } finally {
        // Shut down Server gracefully
        bossGroup.shutdownGracefully();
        workerGroup.shutdownGracefully();
      }
    }
  }

  /** Stops the server gracefully. */
  public void stop() {
    if (bossGroup != null && !bossGroup.isTerminated()) {
      bossGroup.shutdownGracefully();
    }
    if (workerGroup != null && !workerGroup.isTerminated()) {
      workerGroup.shutdownGracefully();
    }
    LOG.info("AsyncServer gracefully shutdown");
  }

  /**
   * This class dynamically accumulate the recieved data by the value of the
   * length field in the message
   */
  public class NioFrameDecoder extends LengthFieldBasedFrameDecoder {

    /**
     * @param maxFrameLength - the maximum length of the frame
     * @param lengthFieldOffset - the offset of the length field
     * @param lengthFieldLength - the length of the length field
     * @param lengthAdjustment - the compensation value to add to the value of
     *          the length field
     * @param initialBytesToStrip - the number of first bytes to strip out from
     *          the decoded frame
     */
    public NioFrameDecoder(int maxFrameLength, int lengthFieldOffset,
        int lengthFieldLength, int lengthAdjustment, int initialBytesToStrip) {
      super(maxFrameLength, lengthFieldOffset, lengthFieldLength,
          lengthAdjustment, initialBytesToStrip);
    }

    /**
     * Decode(Accumulate) the from one ByteBuf to an other
     *
     * @param ctx
     * @param in
     */
    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in)
        throws Exception {
      ByteBuf recvBuff = (ByteBuf) super.decode(ctx, in);
      if (recvBuff == null) {
        return null;
      }
      return recvBuff;
    }
  }

  /**
   * This class process received message from client and send response message.
   */
  private class NioServerInboundHandler extends ChannelInboundHandlerAdapter {
    ConnectionHeader header = new ConnectionHeader();
    Class<?> protocol;
    private String errorClass = null;
    private String error = null;
    private boolean rpcHeaderRead = false; // if initial rpc header is read
    private boolean headerRead = false; // if the connection header that follows
                                        // version is read.

    /**
     * Be invoked only one when a connection is established and ready to
     * generate traffic
     *
     * @param ctx
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) {
      SERVER.set(AsyncServer.this);
    }

    /**
     * Process a recieved message from client. This method is called with the
     * received message, whenever new data is received from a client.
     *
     * @param ctx
     * @param cause
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
      ByteBuffer dataLengthBuffer = ByteBuffer.allocate(4);
      ByteBuf byteBuf = (ByteBuf) msg;

      ByteBuffer data = null;
      ByteBuffer rpcHeaderBuffer = null;
      try {
        while (true) {
          Call call = null;
          errorClass = null;
          error = null;
          try {
            if (dataLengthBuffer.remaining() > 0 && byteBuf.isReadable()) {
              byteBuf.readBytes(dataLengthBuffer);
              if (dataLengthBuffer.remaining() > 0 && byteBuf.isReadable()) {
                return;
              }
            } else {
              return;
            }

            // read rpcHeader
            if (!rpcHeaderRead) {
              // Every connection is expected to send the header.
              if (rpcHeaderBuffer == null) {
                dataLengthBuffer = null;
                dataLengthBuffer = ByteBuffer.allocate(4);
                byteBuf.readBytes(dataLengthBuffer);
                rpcHeaderBuffer = ByteBuffer.allocate(2);
              }
              byteBuf.readBytes(rpcHeaderBuffer);
              if (!rpcHeaderBuffer.hasArray()
                  || rpcHeaderBuffer.remaining() > 0) {
                return;
              }
              int version = rpcHeaderBuffer.get(0);
              byte[] method = new byte[] { rpcHeaderBuffer.get(1) };
              try {
                authMethod = AuthMethod.read(new DataInputStream(
                    new ByteArrayInputStream(method)));
                dataLengthBuffer.flip();
              } catch (IOException ioe) {
                errorClass = ioe.getClass().getName();
                error = StringUtils.stringifyException(ioe);
              }

              if (!HEADER.equals(dataLengthBuffer)
                  || version != CURRENT_VERSION) {
                LOG.warn("Incorrect header or version mismatch from "
                    + address.getHostName() + ":" + address.getPort()
                    + " got version " + version + " expected version "
                    + CURRENT_VERSION);
                return;
              }
              dataLengthBuffer.clear();
              if (authMethod == null) {
                throw new RuntimeException(
                    "Unable to read authentication method");
              }
              rpcHeaderBuffer = null;
              rpcHeaderRead = true;
              continue;
            }

            // read data length and allocate buffer;
            if (data == null) {
              dataLengthBuffer.flip();
              int dataLength = dataLengthBuffer.getInt();
              if (dataLength < 0) {
                LOG.warn("Unexpected data length " + dataLength + "!! from "
                    + address.getHostName());
              }
              data = ByteBuffer.allocate(dataLength);
            }

            // read received data
            byteBuf.readBytes(data);
            if (data.remaining() == 0) {
              dataLengthBuffer.clear();
              data.flip();
              boolean isHeaderRead = headerRead;
              call = processOneRpc(data.array());
              data = null;
              if (!isHeaderRead) {
                continue;
              }
            }
          } catch (OutOfMemoryError oome) {
            // we can run out of memory if we have too many threads
            // log the event and sleep for a minute and give
            // some thread(s) a chance to finish
            //
            LOG.warn("Out of Memory in server select", oome);
            try {
              Thread.sleep(60000);
              errorClass = oome.getClass().getName();
              error = StringUtils.stringifyException(oome);
            } catch (Exception ie) {
            }
          } catch (Exception e) {
            LOG.warn("Exception in Responder "
                + StringUtils.stringifyException(e));
            errorClass = e.getClass().getName();
            error = StringUtils.stringifyException(e);
          }
          sendResponse(ctx, call);
        }
      } finally {
        ReferenceCountUtil.release(msg);
      }
    }

    /**
     * Send response data to client
     *
     * @param ctx
     * @param call
     */
    private void sendResponse(ChannelHandlerContext ctx, Call call) {
      ByteArrayOutputStream buf = new ByteArrayOutputStream(
          INITIAL_RESP_BUF_SIZE);
      Writable value = null;
      try {
        value = call(protocol, call.param, call.timestamp);
      } catch (Throwable e) {
        String logMsg = this.getClass().getName() + ", call " + call
            + ": error: " + e;
        if (e instanceof RuntimeException || e instanceof Error) {
          // These exception types indicate something is probably wrong
          // on the server side, as opposed to just a normal exceptional
          // result.
          LOG.warn(logMsg, e);
        } else if (exceptionsHandler.isTerse(e.getClass())) {
          // Don't log the whole stack trace of these exceptions.
          // Way too noisy!
          LOG.info(logMsg);
        } else {
          LOG.info(logMsg, e);
        }
        errorClass = e.getClass().getName();
        error = StringUtils.stringifyException(e);
      }
      try {
        setupResponse(buf, call, (error == null) ? Status.SUCCESS
            : Status.ERROR, value, errorClass, error);
        if (buf.size() > maxRespSize) {
          LOG.warn("Large response size " + buf.size() + " for call "
              + call.toString());
          buf = new ByteArrayOutputStream(INITIAL_RESP_BUF_SIZE);
        }
        // send response data;
        channelWrite(ctx, call.response);
      } catch (Exception e) {
        LOG.info(this.getClass().getName() + " caught: "
            + StringUtils.stringifyException(e));
        error = null;
      } finally {
        IOUtils.closeStream(buf);
      }
    }

    /**
     * read header or data
     *
     * @param buf
     * @return
     */
    private Call processOneRpc(byte[] buf) throws IOException {
      if (headerRead) {
        return processData(buf);
      } else {
        processHeader(buf);
        headerRead = true;
        return null;
      }
    }

    /**
     * Reads the connection header following version
     *
     * @param buf buffer
     */
    private void processHeader(byte[] buf) {
      DataInputStream in = new DataInputStream(new ByteArrayInputStream(buf));
      try {
        header.readFields(in);
        String protocolClassName = header.getProtocol();
        if (protocolClassName != null) {
          protocol = getProtocolClass(header.getProtocol(), conf);
        }
      } catch (Exception e) {
        throw new RuntimeException(e);
      } finally {
        IOUtils.closeStream(in);
      }

      UserGroupInformation protocolUser = header.getUgi();
      user = protocolUser;
    }

    /**
     *
     * Reads the received data, create call object;
     *
     * @param buf buffer to serialize the response into
     * @return the IPC Call
     */
    private Call processData(byte[] buf) {
      DataInputStream dis = new DataInputStream(new ByteArrayInputStream(buf));
      try {
        int id = dis.readInt(); // try to read an id

        if (LOG.isDebugEnabled())
          LOG.debug(" got #" + id);
        Writable param = ReflectionUtils.newInstance(paramClass, conf);
        param.readFields(dis); // try to read param data

        Call call = new Call(id, param, this);

        return call;
      } catch (Exception e) {
        throw new RuntimeException(e);
      } finally {
        IOUtils.closeStream(dis);
      }
    }
  }

  /**
   * Setup response for the IPC Call.
   *
   * @param response buffer to serialize the response into
   * @param call {@link Call} to which we are setting up the response
   * @param status {@link Status} of the IPC call
   * @param rv return value for the IPC Call, if the call was successful
   * @param errorClass error class, if the the call failed
   * @param error error message, if the call failed
   * @throws IOException
   */
  private void setupResponse(ByteArrayOutputStream response, Call call,
      Status status, Writable rv, String errorClass, String error)
      throws IOException {
    response.reset();
    DataOutputStream out = new DataOutputStream(response);
    out.writeInt(call.id); // write call id
    out.writeInt(status.state); // write status

    if (status == Status.SUCCESS) {
      rv.write(out);
    } else {
      WritableUtils.writeString(out, errorClass);
      WritableUtils.writeString(out, error);
    }
    call.setResponse(ByteBuffer.wrap(response.toByteArray()));
    IOUtils.closeStream(out);
  }

  /**
   * This is a wrapper around {@link WritableByteChannel#write(ByteBuffer)}. If
   * the amount of data is large, it writes to channel in smaller chunks. This
   * is to avoid jdk from creating many direct buffers as the size of buffer
   * increases. This also minimizes extra copies in NIO layer as a result of
   * multiple write operations required to write a large buffer.
   *
   * @see WritableByteChannel#write(ByteBuffer)
   *
   * @param ctx
   * @param buffer
   */
  private void channelWrite(ChannelHandlerContext ctx, ByteBuffer buffer) {
    try {
      ByteBuf buf = ctx.alloc().buffer();
      buf.writeBytes(buffer.array());
      ctx.writeAndFlush(buf);
    } catch (Throwable e) {
      e.printStackTrace();
    }
  }

  /** A call queued for handling. */
  private static class Call {
    private int id; // the client's call id
    private Writable param; // the parameter passed
    private ChannelInboundHandlerAdapter connection; // connection to client
    private long timestamp; // the time received when response is null
    // the time served when response is not null
    private ByteBuffer response; // the response for this call

    /**
     *
     * @param id
     * @param param
     * @param connection
     */
    public Call(int id, Writable param, ChannelInboundHandlerAdapter connection) {
      this.id = id;
      this.param = param;
      this.connection = connection;
      this.timestamp = System.currentTimeMillis();
      this.response = null;
    }

    /**
     *
     */
    @Override
    public String toString() {
      return param.toString() + " from " + connection.toString();
    }

    /**
     *
     * @param response
     */
    public void setResponse(ByteBuffer response) {
      this.response = response;
    }
  }

  /**
   * ExceptionsHandler manages Exception groups for special handling e.g., terse
   * exception group for concise logging messages
   */
  static class ExceptionsHandler {
    private volatile Set<String> terseExceptions = new HashSet<String>();

    /**
     * Add exception class so server won't log its stack trace. Modifying the
     * terseException through this method is thread safe.
     *
     * @param exceptionClass exception classes
     */
    void addTerseExceptions(Class<?>... exceptionClass) {

      // Make a copy of terseException for performing modification
      final HashSet<String> newSet = new HashSet<String>(terseExceptions);

      // Add all class names into the HashSet
      for (Class<?> name : exceptionClass) {
        newSet.add(name.toString());
      }
      // Replace terseException set
      terseExceptions = Collections.unmodifiableSet(newSet);
    }

    /**
     *
     * @param t
     * @return
     */
    boolean isTerse(Class<?> t) {
      return terseExceptions.contains(t.toString());
    }
  }

  /**
   * Called for each call.
   *
   * @param protocol
   * @param param
   * @param receiveTime
   * @return Writable
   * @throws IOException
   */
  public abstract Writable call(Class<?> protocol, Writable param,
      long receiveTime) throws IOException;
}
TOP

Related Classes of org.apache.hama.ipc.AsyncServer$Call

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.