Package com.starlight.intrepid

Source Code of com.starlight.intrepid.RemoteCallHandler$InvokeNotAckedFlag

// Copyright (c) 2010 Rob Eden.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//     * Redistributions of source code must retain the above copyright
//       notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above copyright
//       notice, this list of conditions and the following disclaimer in the
//       documentation and/or other materials provided with the distribution.
//     * Neither the name of Intrepid nor the
//       names of its contributors may be used to endorse or promote products
//       derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
// DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package com.starlight.intrepid;

import com.starlight.NotNull;
import com.starlight.ValidationKit;
import com.starlight.intrepid.auth.*;
import com.starlight.intrepid.exception.*;
import com.starlight.intrepid.message.*;
import com.starlight.intrepid.spi.*;
import com.starlight.listeners.ListenerSupport;
import com.starlight.locale.FormattedTextResourceKey;
import com.starlight.thread.ObjectSlot;
import com.starlight.thread.ScheduledExecutor;
import com.starlight.thread.SharedThreadPool;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.TShortObjectMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TShortObjectHashMap;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.procedure.TObjectProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;


/**
*
*/
class RemoteCallHandler implements InboundMessageHandler {
  private static final Logger LOG = LoggerFactory.getLogger( RemoteCallHandler.class );


  private static final int MAX_CHANNEL_MESSAGE_DATA_SIZE =
    Integer.getInteger( "intrepid.channel.max_message_data",
    ( 1 << 18 ) - 20 ).intValue()// 256K - 20 bytes

  private static final byte REQUEST_INVOKE_ACK_RATE_SEC = ( byte )
    Math.min( 10,
      Integer.getInteger( "intrepid.req_invoke_ack_rate_sec", 10 ).intValue() );
  static {
    if ( REQUEST_INVOKE_ACK_RATE_SEC < 0 ) {
      throw new IllegalArgumentException(
        "System property \"intrepid.req_invoke_ack_rate_sec\" may not be set " +
        "to a negative number." );
    }
  }

  private static final InvokeCloseFlag SESSION_CLOSED_FLAG = new InvokeCloseFlag();

  private static final InvokeNotAckedFlag NEVER_ACKED_FLAG =
    new InvokeNotAckedFlag( false );
  private static final InvokeNotAckedFlag SUBSEQUENT_NOT_ACKED_FLAG =
    new InvokeNotAckedFlag( true );



  private enum InvokeAttempt { BY_ID, BY_PERSISTENT_NAME }
  private static final InvokeAttempt[] INVOKE_ATTEMPT = InvokeAttempt.values();



  private final AuthenticationHandler auth_handler;
  private final IntrepidSPI spi;
  private final LocalCallHandler local_handler;
  private final VMID local_vmid;
  private final ScheduledExecutor executor;
  private final ChannelAcceptor channel_acceptor;
  private Intrepid instance;    // for call context info

  private final ListenerSupport<PerformanceListener,?> performance_listeners;

  private final AtomicInteger call_id_counter = new AtomicInteger( 0 );

  private final Lock call_wait_map_lock = new ReentrantLock();
  private final TIntObjectMap<CallInfoAndAckControl> call_wait_map =
    new TIntObjectHashMap<CallInfoAndAckControl>();

  // Map of VMID to active calls
  private final Map<VMID,TIntSet> vmid_call_wait_map = new HashMap<VMID,TIntSet>();


  private final Lock ping_wait_map_lock = new ReentrantLock();
  private final TShortObjectHashMap<ObjectSlot<PingResponseIMessage>> ping_wait_map =
    new TShortObjectHashMap<ObjectSlot<PingResponseIMessage>>();


  private final Lock channel_map_lock = new ReentrantLock();
  private final TIntObjectMap<ObjectSlot<ChannelInitResponseIMessage>> channel_init_wait_map =
    new TIntObjectHashMap<ObjectSlot<ChannelInitResponseIMessage>>();

  // Map of active virtual channels
  private final Map<VMID,TShortObjectMap<VirtualByteChannel>> channel_map =
    new HashMap<VMID,TShortObjectMap<VirtualByteChannel>>();

  private final AtomicInteger channel_id_counter = new AtomicInteger();


  private final TIntObjectMap<InvokeRunner> runner_map =
    new TIntObjectHashMap<InvokeRunner>();
  private final Lock runner_map_lock = new ReentrantLock();

  RemoteCallHandler( IntrepidSPI spi, AuthenticationHandler auth_handler,
    LocalCallHandler local_handler, VMID local_vmid, ScheduledExecutor executor,
    ListenerSupport<PerformanceListener, ?> performance_listeners,
    ChannelAcceptor channel_acceptor ) {

    this.auth_handler = auth_handler;
    this.spi = spi;
    this.local_handler = local_handler;
    this.local_vmid = local_vmid;
    this.executor = executor;
    this.channel_acceptor = channel_acceptor;
    this.performance_listeners = performance_listeners;
  }

  void initInstance( Intrepid instance ) {
    this.instance = instance;
  }


  public Object invoke( VMID vmid, int object_id, int method_id, String persistent_name,
    Object[] args, Method method ) throws Throwable {

    int call_id = call_id_counter.getAndIncrement();

    final ObjectSlot<InvokeReturnIMessage> return_slot =
      new ObjectSlot<InvokeReturnIMessage>();

    // Make sure args are serializable and wrap in proxy if they're not
    if ( args != null && args.length > 0 ) checkArgsForSerialization( args );

    UserContextInfo user_context = null;
    if ( IntrepidContext.isCall() ) {
      user_context = IntrepidContext.getUserInfo();
    }

    final boolean has_perf_listeners = performance_listeners.hasListeners();

    if ( has_perf_listeners ) {
      performance_listeners.dispatch().remoteCallStarted( local_vmid,
        System.currentTimeMillis(), call_id, vmid, object_id, method_id, method,
        args, user_context, null );
    }

    InvokeIMessage message = new InvokeIMessage( call_id, object_id, null, method_id,
      args, user_context, has_perf_listeners );
    AtomicInteger protocol_version_slot = new AtomicInteger();

    CallInfoAndAckControl call_info =
      new CallInfoAndAckControl( executor, return_slot );

    // Put the slot in the call_wait_map
    call_wait_map_lock.lock();
    try {
      call_wait_map.put( call_id, call_info );

      TIntSet call_id_set = vmid_call_wait_map.get( vmid );
      if ( call_id_set == null ) {
        call_id_set = new TIntHashSet();
        vmid_call_wait_map.put( vmid, call_id_set );
      }
      call_id_set.add( call_id );
    }
    finally {
      call_wait_map_lock.unlock();
    }

    try {
      VMID new_vmid = null;
      Integer new_object_id = null;
      Throwable t = null;
      for( InvokeAttempt attempt : INVOKE_ATTEMPT ) {
        if ( attempt == InvokeAttempt.BY_PERSISTENT_NAME ) {
          if ( persistent_name == null ) break;
         
          if ( has_perf_listeners ) {
            performance_listeners.dispatch().remoteCallStarted( local_vmid,
              System.currentTimeMillis(), call_id, vmid, object_id,
              method_id, method, args, user_context, persistent_name );
          }

          message = new InvokeIMessage( call_id, object_id, persistent_name,
            method_id, args, user_context,
            performance_listeners.hasListeners() );
        }

        Integer call_id_obj = Integer.valueOf( call_id );
        LOG.debug( "Sending message {}", call_id_obj );
        assert message != null;
        assert vmid != null;

        SessionInfo session_info =
          spi.sendMessage( vmid, message, protocol_version_slot );
        LOG.debug( "Message sent: {}", call_id_obj );

        VMID current_vmid = session_info.getVMID();
        if ( current_vmid != null && !current_vmid.equals( vmid ) ) {
          new_vmid = current_vmid;
        }


        // If the connection supports method acks, set up timer to expect it.
        if ( ProtocolVersions.supportsMethodAck(
          protocol_version_slot.byteValue() ) ) {

          Byte session_ack_rate = session_info.getAckRateSec();
          if ( session_ack_rate != null ) {
            call_info.setSessionAckRateSec( session_ack_rate.byteValue() );
          }

          call_info.scheduleAckExpector( true );
        }


        InvokeReturnIMessage return_message;
        try {
          // Wait for the response
          LOG.debug( "Waiting for slot: {}", call_id_obj );
          return_message = return_slot.waitForValue();
        }
        catch( InterruptedException ex ) {
          // send interrupt message
          try {
            spi.sendMessage( vmid,
              new InvokeInterruptIMessage( call_id ), null );
          }
          catch( Exception exc ) {
            // ignore
          }
          throw new InterruptedCallException( ex );
        }

        if ( has_perf_listeners ) {
          performance_listeners.dispatch().remoteCallCompleted( local_vmid,
            System.currentTimeMillis(), call_id, return_message.getValue(),
            return_message.isThrown(), return_message.getServerTimeNano() );
        }

        new_object_id = return_message.getNewObjectID();

        if ( return_message.isThrown() ) {
          t = ( Throwable ) return_message.getValue();

          // Look for an indicator that the object wasn't found. If that's the
          // case and we have a persistent name available (and we haven't tried
          // already), then retry the loop which will retry with the persistent
          // name.
          if ( t instanceof UnknownObjectException &&
            attempt == InvokeAttempt.BY_ID &&
            persistent_name != null ) {

            return_slot.clear();
            //noinspection UnnecessaryContinue
            continue;
          }
          else break;
        }
        else {
          // If there's a new VMID or ObjectID, pass it on
          if ( new_vmid != null || new_object_id != null ) {
            throw new NewIDIndicator( new_vmid,
              new_object_id == null ? object_id : new_object_id.intValue(),
              return_message.getValue(), false );
          }
          else return return_message.getValue();
        }
      }

      if ( !( t instanceof InterruptedCallException ) && t instanceof Error ) {
        t = new ServerException( t );
      }

      // If there's a new VMID or ObjectID, pass it on
      if ( new_vmid != null || new_object_id != null ) {
        throw new NewIDIndicator( new_vmid,
          new_object_id == null ? object_id : new_object_id.intValue(),
          t, true );
      }
      else throw t;
    }
    finally {
      call_info.dispose();

      // Make sure nothing is left in the call_wait_map
      call_wait_map_lock.lock();
      try {
        call_wait_map.remove( call_id );

        TIntSet call_id_set = vmid_call_wait_map.get( vmid );
        if ( call_id_set != null ) call_id_set.remove( call_id );
      }
      finally {
        call_wait_map_lock.unlock();
      }
    }
  }


  ByteChannel channelCreate( VMID destination, Serializable attachment )
    throws IOException, ChannelRejectedException {

    final int call_id = call_id_counter.getAndIncrement();

    ObjectSlot<ChannelInitResponseIMessage> response_slot =
      new ObjectSlot<ChannelInitResponseIMessage>();
    boolean successful = false;
    short channel_id = -1;
    try {
      ChannelInitIMessage channel_init;
      VirtualByteChannel channel;

      channel_map_lock.lock();
      try {
        // Prepare for the response with the call ID
        channel_init_wait_map.put( call_id, response_slot );

       
        TShortObjectMap<VirtualByteChannel> channel_id_map =
          channel_map.get( destination );
        if ( channel_id_map == null ) {
          channel_id_map = new TShortObjectHashMap<VirtualByteChannel>();
          channel_map.put( destination, channel_id_map );
        }

        // Determine a free channel ID 
        while( true ) {
          int channel_id_tmp = channel_id_counter.getAndIncrement();
          if ( channel_id_tmp > Short.MAX_VALUE ) {
            channel_id_counter.set( 0 );
            channel_id_tmp = 0;
          }
 
          if ( channel_id_map.containsKey( ( short ) channel_id_tmp ) ) continue;
 
          channel_id = ( short ) channel_id_tmp;
          break;
        }
       
        channel_init = new ChannelInitIMessage( call_id, attachment, channel_id );
        channel = new VirtualByteChannel( destination, channel_id, this );
       
        // Register the channel so we're ready to receive data immediately
        channel_id_map.put( channel_id, channel );
      }
      finally {
        channel_map_lock.unlock();
      }

      spi.sendMessage( destination, channel_init, null );

     
      ChannelInitResponseIMessage response = response_slot.waitForValue();

      if ( response.isSuccessful() ) {
        if ( performance_listeners.hasListeners() ) {
          performance_listeners.dispatch().virtualChannelOpened(
            local_vmid, destination, channel_id );
        }
        successful = true;

        return channel;
      }
      else throw new ChannelRejectedException( response.getRejectReason() );
    }
    catch( InterruptedException ex ) {
      throw new InterruptedIOException();
    }
    finally {
      if ( !successful ) {
        // Make sure nothing is left in the channel maps
        channel_map_lock.lock();
        try {
          channel_init_wait_map.remove( call_id );
         
          TShortObjectMap<VirtualByteChannel> channel_id_map =
            channel_map.get( destination );
          if ( channel_id_map != null && channel_id != -1 ) {
            channel_id_map.remove( channel_id );
            if ( channel_id_map.isEmpty() ) channel_map.remove( destination );
          }
        }
        finally {
          channel_map_lock.unlock();
        }
      }
    }
  }


  void channelClose( final VMID destination, final short channel_id,
    boolean send_message ) {

    if ( send_message ) {
      try {
        spi.sendMessage( destination,
          new ChannelCloseIMessage( channel_id ), null );
      }
      catch( NotConnectedException ex ) {
        // ignore
      }
      catch ( IOException ex ) {
        LOG.warn( "Unable to send channel close message: {}",
          Short.valueOf( channel_id ), ex );
      }
    }

    channel_map_lock.lock();
    try {
      TShortObjectMap<VirtualByteChannel> channel_id_map =
        channel_map.get( destination );
      if ( channel_id_map != null ) {
        VirtualByteChannel channel = channel_id_map.remove( channel_id );
        if ( channel != null ) channel.closedByPeer( false );
       
        if ( channel_id_map.isEmpty() ) channel_map.remove( destination );
      }
    }
    finally {
      channel_map_lock.unlock();
    }

    if ( performance_listeners.hasListeners() ) {
      performance_listeners.dispatch().virtualChannelClosed( local_vmid,
        destination, channel_id );
    }
  }

  void channelSendData( VMID destination, short channel_id, ByteBuffer data )
    throws IOException {

    if ( !data.hasRemaining() ) return;

    int bytes = data.remaining();

    int original_limit = data.limit();
    while( data.position() < original_limit ) {
      data.limit( Math.min( original_limit,
        data.position() + MAX_CHANNEL_MESSAGE_DATA_SIZE ) );
      spi.sendMessage( destination,
        new ChannelDataIMessage( channel_id, data ), null );
    }

    if ( performance_listeners.hasListeners() ) {
      performance_listeners.dispatch().virtualChannelDataSent( local_vmid,
        destination, channel_id, bytes );
    }
  }

  long ping( VMID vmid, long timeout, TimeUnit timeout_unit )
    throws TimeoutException, IntrepidRuntimeException, InterruptedException {

    short id = ( short ) call_id_counter.getAndIncrement();

    ObjectSlot<PingResponseIMessage> response_slot =
      new ObjectSlot<PingResponseIMessage>();

    ping_wait_map_lock.lock();
    try {
      ping_wait_map.put( id, response_slot );
    }
    finally {
      ping_wait_map_lock.unlock();
    }

    long start = System.nanoTime();

    try {
      spi.sendMessage( vmid, new PingIMessage( id ), null );

      PingResponseIMessage response =
        response_slot.waitForValue( timeout_unit.toMillis( timeout ) );
      if ( response == null ) throw new TimeoutException();
      else return TimeUnit.NANOSECONDS.toMillis( System.nanoTime() - start );
    }
    catch( IOException ex ) {
      throw new IntrepidRuntimeException( ex );
    }
    finally {
      ping_wait_map_lock.lock();
      try {
        ping_wait_map.remove( id );
      }
      finally {
        ping_wait_map_lock.unlock();
      }
    }
  }



  @Override
  public IMessage receivedMessage( SessionInfo session_info, IMessage message )
    throws CloseSessionIndicator {

    LOG.debug( "receivedMessage from {}: {}", session_info.getVMID(), message );

    IMessage response = null;
    switch ( message.getType() ) {
      case SESSION_INIT:
        response = handleSessionInit( ( SessionInitIMessage ) message,
          session_info );
        break;

      case SESSION_INIT_RESPONSE:
        handleSessionInitResponse( ( SessionInitResponseIMessage ) message,
          session_info );
        break;

      case SESSION_TOKEN_CHANGE:
        handleSessionTokenChange( ( SessionTokenChangeIMessage ) message,
          session_info );
        break;

      case SESSION_CLOSE:
        handleSessionClose( ( SessionCloseIMessage ) message );
        break;

      case INVOKE:
        handleInvoke( ( InvokeIMessage ) message, session_info );
        break;

      case INVOKE_RETURN:
        handleInvokeReturn( ( InvokeReturnIMessage ) message );
        break;

      case INVOKE_INTERRUPT:
        handleInvokeInterrupt( ( InvokeInterruptIMessage ) message );
        break;

      case INVOKE_ACK:
        handlerInvokeAck( ( InvokeAckIMessage ) message, session_info.getVMID() );
        break;

      case LEASE:
        LeaseManager.handleLease( ( LeaseIMessage ) message, local_handler,
          session_info.getVMID() );
        break;

      case LEASE_RELEASE:
        LeaseManager.handleLeaseRelease( ( LeaseReleaseIMessage ) message,
          local_handler, session_info.getVMID() );
        break;

      case CHANNEL_INIT:
        response = handleChannelInit( ( ChannelInitIMessage ) message,
          session_info.getVMID() );
        break;

      case CHANNEL_INIT_RESPONSE:
        handleChannelInitResponse( ( ChannelInitResponseIMessage ) message );
        break;

      case CHANNEL_DATA:
        handleChannelData( ( ChannelDataIMessage ) message, session_info.getVMID() );
        break;

      case CHANNEL_CLOSE:
        handleChannelClose( ( ChannelCloseIMessage ) message,
          session_info.getVMID() );
        break;

      case PING:
        response = handlePing( ( PingIMessage ) message );
        break;

      case PING_RESPONSE:
        handlePingResponse( ( PingResponseIMessage ) message );
        break;

      default:
        assert false : "Unknown type: " + message.getType();
        throw new CloseSessionIndicator( new SessionCloseIMessage(
          new FormattedTextResourceKey( Resources.UNKNOWN_MESSAGE_TYPE,
          message.getType().name() ), false ) );
    }

    return response;
  }

  @Override
  public boolean sessionClosed( SessionInfo session_info, boolean opened_locally,
    boolean closed_locally, boolean can_reconnect ) {

    VMID vmid = session_info.getVMID();
    if ( vmid != null ) {
      // Close active method calls
      call_wait_map_lock.lock();
      try {
        TIntSet call_id_set = vmid_call_wait_map.get( vmid );
        if ( call_id_set != null && !call_id_set.isEmpty() ) {
          call_id_set.forEach( new CallInterruptProcedure() );
        }
      }
      finally {
        call_wait_map_lock.unlock();
      }

      // Close active virtual channels
      TShortObjectMap<VirtualByteChannel> channel_id_map;
      channel_map_lock.lock();
      try {
        channel_id_map = channel_map.remove( vmid );
      }
      finally {
        channel_map_lock.unlock();
      }

      if ( channel_id_map != null ) {
        channel_id_map.forEachValue( new ChannelCloseProcedure() );
      }
    }

    // Attempt reconnection if:
    //  1) It's possible
    //  2) We initiated the connection originally
    //  3) We didn't close the connection
    return can_reconnect && opened_locally && !closed_locally;
  }


  @Override
  public IMessage sessionOpened( SessionInfo session_info, boolean opened_locally,
    ConnectionArgs connection_args ) throws CloseSessionIndicator {

    // Only care about sessions we opened
    if ( !opened_locally ) return null;

    return new SessionInitIMessage( local_vmid, spi.getServerPort(),
      connection_args, ProtocolVersions.MIN_PROTOCOL_VERSION,
      ProtocolVersions.PROTOCOL_VERSION, session_info.getReconnectToken(),
      REQUEST_INVOKE_ACK_RATE_SEC );
  }


  private IMessage handleSessionInit( SessionInitIMessage message,
    SessionInfo session_info ) throws CloseSessionIndicator {

    if ( auth_handler == null ) {
      throw new CloseSessionIndicator( new SessionCloseIMessage(
        Resources.ERROR_CLIENT_CONNECTIONS_NOT_ALLOWED_NO_AUTH_HANDLER, true ) );
    }

    byte proto_version = ProtocolVersions.negotiateProtocolVersion(
      message.getMinProtocolVersion(), message.getPrefProtocolVersion() );
    if ( proto_version < 0 ) {
      throw new CloseSessionIndicator( new SessionCloseIMessage(
        new FormattedTextResourceKey( Resources.INCOMPATIBLE_PROTOCOL_VERSION,
        Byte.valueOf( ProtocolVersions.MIN_PROTOCOL_VERSION ),
        Byte.valueOf( ProtocolVersions.PROTOCOL_VERSION ),
        Byte.valueOf( message.getMinProtocolVersion() ),
        Byte.valueOf( message.getPrefProtocolVersion() ) ), false ) );
    }
   
    // Handle session re-init (see class docs on RequestUserCredentialReinit)
    if ( message.getConnectionArgs() != null &&
      message.getConnectionArgs() instanceof RequestUserCredentialReinit ) {

      if ( auth_handler instanceof UserCredentialReinitAuthenticationHandler ) {
        UserCredentialReinitAuthenticationHandler handler =
          ( UserCredentialReinitAuthenticationHandler ) auth_handler;

        try {
          UserCredentialsConnectionArgs new_args = handler.getUserCredentials(
            session_info.getRemoteAddress(), session_info.getSessionSource() );
          return new SessionInitIMessage( local_vmid, spi.getServerPort(),
            new_args, ProtocolVersions.MIN_PROTOCOL_VERSION,
            ProtocolVersions.PROTOCOL_VERSION,
            session_info.getReconnectToken(), REQUEST_INVOKE_ACK_RATE_SEC );
        }
        catch( ConnectionAuthFailureException ex ) {
          throw new CloseSessionIndicator(
            new SessionCloseIMessage( ex.getMessageResourceKey(), true ) );
        }
      }
      else {
        throw new CloseSessionIndicator( new SessionCloseIMessage(
          Resources.ERROR_USER_REINIT_CONNECTIONS_NOT_ALLOWED, true ) );
      }
    }

    // "Normal" connection...
    UserContextInfo user_context;
    Serializable reconnect_token = null;
    try {
      // Session token reconnection added in protocol version 1.
      if ( ProtocolVersions.supportsReconnectTokens( proto_version ) &&
        auth_handler instanceof TokenReconnectAuthenticationHandler ) {

        TokenReconnectAuthenticationHandler token_handler =
          ( TokenReconnectAuthenticationHandler ) auth_handler;
        user_context = token_handler.checkConnection( message.getConnectionArgs(),
          session_info.getRemoteAddress(), session_info.getSessionSource(),
          message.getReconnectToken() );

        reconnect_token = generateReconnectToken( session_info, token_handler,
          message.getInitiatorVMID(), user_context, message.getConnectionArgs(),
          session_info.getRemoteAddress(), session_info.getSessionSource(),
          message.getReconnectToken(), false );
      }
      else {
        user_context = auth_handler.checkConnection( message.getConnectionArgs(),
          session_info.getRemoteAddress(), session_info.getSessionSource() );
      }
    }
    catch( ConnectionAuthFailureException ex ) {
      throw new CloseSessionIndicator(
        new SessionCloseIMessage( ex.getMessageResourceKey(), true ) );
    }

    session_info.setProtocolVersion( Byte.valueOf( proto_version ) );
    // NOTE: set user context first, so it's available when the connectionOpened
    //       message is fired.
    session_info.setUserContext( user_context );
    // NOTE: MUST come before setVMID
    session_info.setPeerServerPort( message.getInitiatorServerPort() );


    byte ack_rate = REQUEST_INVOKE_ACK_RATE_SEC;
    if ( message.getRequestedAckRateSec() > 0 &&
      message.getRequestedAckRateSec() < ack_rate ) {

      ack_rate = message.getRequestedAckRateSec();
    }

    session_info.setVMID( message.getInitiatorVMID(), ack_rate );

    return new SessionInitResponseIMessage( local_vmid, spi.getServerPort(),
      proto_version, reconnect_token, ack_rate );
  }

  private void handleSessionInitResponse( SessionInitResponseIMessage message,
    SessionInfo session_info ) {

    byte ack_rate = message.getAckRateSec();
    if ( ack_rate <= 0 ) ack_rate = REQUEST_INVOKE_ACK_RATE_SEC;

    session_info.setProtocolVersion( Byte.valueOf( message.getProtocolVersion() ) );
    session_info.setVMID( message.getResponderVMID(), ack_rate );
    session_info.setReconnectToken( message.getReconnectToken() );
  }

  private void handleSessionTokenChange( SessionTokenChangeIMessage message,
    SessionInfo session_info ) {

    session_info.setReconnectToken( message.getNewReconnectToken() );

    if ( LOG.isDebugEnabled() ) {
      LOG.debug( "Session reconnect token changed for connection to {}: {}",
        session_info.getVMID(), message.getNewReconnectToken() );
    }
  }

  private void handleSessionClose( SessionCloseIMessage message )
    throws CloseSessionIndicator {

    if ( auth_handler instanceof UserCredentialReinitAuthenticationHandler &&
      message.isAuthFailure() ) {

      UserCredentialReinitAuthenticationHandler handler =
        ( UserCredentialReinitAuthenticationHandler ) auth_handler;
      handler.notifyUserCredentialFailure( message.getReason() );
    }

//    System.out.println( "Notified of close for session with " + info.getVMID() +
//      ": " + message.getReason() );
    throw new CloseSessionIndicator( message.getReason() );
  }

  private void handleInvoke( InvokeIMessage message,
    SessionInfo session_info ) {

    LOG.trace( "Invoke: {} (protocol version={})", message,
      session_info.getProtocolVersion() );

    // Use the messages call context, unless a context has already been set for the
    // session. In other words, only allow an overriding context if this is a server
    // connection.
    UserContextInfo context_info = session_info.getUserContext();
    if ( context_info == null ) context_info = message.getUserContext();

    if ( performance_listeners.hasListeners() ) {
      performance_listeners.dispatch().inboundRemoteCallStarted( local_vmid,
        System.currentTimeMillis(), message.getCallID(), session_info.getVMID(),
        message.getObjectID(), message.getMethodID(),
        local_handler.lookupMethodForID( message.getObjectID(),
          message.getMethodID() ),
        message.getArgs(), context_info, message.getPersistentName() );
    }

    InetAddress source_address = null;
    SocketAddress sock_addr = session_info.getRemoteAddress();
    if ( sock_addr != null && sock_addr instanceof InetSocketAddress ) {
      source_address = ( ( InetSocketAddress ) sock_addr ).getAddress();
    }

    Byte protocol_version = session_info.getProtocolVersion();
    assert protocol_version != null :
      "Unknown protocol version for session: " + session_info;
    //noinspection ConstantConditions
    boolean needs_ack = protocol_version == null ||
      ProtocolVersions.supportsMethodAck( protocol_version.byteValue() );

    Byte ack_rate = session_info.getAckRateSec();

    InvokeRunner runner = new InvokeRunner( message, session_info.getVMID(),
      source_address, context_info, spi, local_handler, instance, runner_map,
      runner_map_lock, performance_listeners, needs_ack,
      ack_rate == null ? REQUEST_INVOKE_ACK_RATE_SEC : ack_rate.byteValue(),
      executor );

    runner_map_lock.lock();
    try {
      runner_map.put( message.getCallID(), runner );
    }
    finally {
      runner_map_lock.unlock();
    }

    executor.execute( runner );
  }

  private void handleInvokeReturn( InvokeReturnIMessage message ) {
    CallInfoAndAckControl call_info;

    LOG.trace( "Invoke return: {}", message );

    call_wait_map_lock.lock();
    try {
      call_info = call_wait_map.get( message.getCallID() );
    }
    finally {
      call_wait_map_lock.unlock();
    }

    if ( call_info == null ) {
      if ( LOG.isDebugEnabled() ) {
        LOG.debug( "No info found for call {}, message: {}",
          Integer.valueOf( message.getCallID() ), message );
      }
      return;
    }

    call_info.return_slot.set( message );
  }

  private void handleInvokeInterrupt( InvokeInterruptIMessage message ) {
    LOG.trace( "Invoke interrupt: {}", message );

    InvokeRunner runner;
    runner_map_lock.lock();
    try {
      runner = runner_map.get( message.getCallID() );
    }
    finally {
      runner_map_lock.unlock();
    }

    if ( runner != null ) runner.interrupt();
  }


  private void handlerInvokeAck( InvokeAckIMessage message, VMID caller_vmid ) {
    LOG.trace( "Invoke ack from {}: {}", caller_vmid, message );

    CallInfoAndAckControl call_info;
    call_wait_map_lock.lock();
    try {
      call_info = call_wait_map.get( message.getCallID() );
    }
    finally {
      call_wait_map_lock.unlock();
    }

    if ( call_info == null ) {
      if ( LOG.isDebugEnabled() ) {
        LOG.debug( "No info found for call {}, message: {}",
          Integer.valueOf( message.getCallID() ), message );
      }
      return;
    }

    call_info.scheduleAckExpector( false );
  }


  private ChannelInitResponseIMessage handleChannelInit( ChannelInitIMessage message,
    VMID vmid ) {

    if ( channel_acceptor == null ) {
      return new ChannelInitResponseIMessage( message.getRequestID(), null );
    }

    VirtualByteChannel channel;
    final short channel_id = message.getChannelID();

    channel_map_lock.lock();
    try {
      TShortObjectMap<VirtualByteChannel> channel_id_map = channel_map.get( vmid );
      if ( channel_id_map == null ) {
        channel_id_map = new TShortObjectHashMap<VirtualByteChannel>();
        channel_map.put( vmid, channel_id_map );
      }

      channel = new VirtualByteChannel( vmid, message.getChannelID(), this );

      VirtualByteChannel prev_channel = channel_id_map.put( channel_id, channel );
      if ( prev_channel != null ) {
        assert false : "Duplicate channel ID " + channel_id;
        LOG.warn( "Duplicate channel ID: {}", Short.valueOf( channel_id ) );
        prev_channel.closedByPeer( true );
      }
    }
    finally {
      channel_map_lock.unlock();
    }

    try {
      // See if the channel is approved
      channel_acceptor.newChannel( channel, vmid, message.getAttachment() );
    }
    catch( ChannelRejectedException ex ) {
      // Don't close the channel directly, but call our close method (without
      // sending a message) to clean up internal maps.
      channelClose( vmid, channel_id, false );

      // NOTE: no need to close VirtualByteChannel
      return new ChannelInitResponseIMessage( message.getRequestID(),
        ex.getMessageResourceKey() );
    }

    if ( performance_listeners.hasListeners() ) {
      performance_listeners.dispatch().virtualChannelOpened( local_vmid,
        vmid, channel_id );
    }

    return new ChannelInitResponseIMessage( message.getRequestID() );
  }

  private void handleChannelInitResponse( ChannelInitResponseIMessage message ) {
    channel_map_lock.lock();
    try {
      ObjectSlot<ChannelInitResponseIMessage> response_slot =
        channel_init_wait_map.get( message.getRequestID() );

      if ( response_slot != null ) response_slot.set( message );
    }
    finally {
      channel_init_wait_map.remove( message.getRequestID() );

      channel_map_lock.unlock();
    }
  }

  private void handleChannelData( ChannelDataIMessage message, VMID vmid ) {
    VirtualByteChannel channel = null;
    channel_map_lock.lock();
    try {
      TShortObjectMap<VirtualByteChannel> channel_id_map = channel_map.get( vmid );
      if ( channel_id_map != null ) {
        channel = channel_id_map.get( message.getChannelID() );
      }
    }
    finally {
      channel_map_lock.unlock();
    }

    if ( channel == null ) return;      // Log? Send error?

    // NOTE: VirtualByteChannel currently hangs on to these buffers
    final int buffers = message.getBufferCount();
    int bytes = 0;
    for( int i = 0; i < buffers; i++ ) {
      ByteBuffer buffer = message.getBuffer( i );
      bytes += buffer.remaining();
      channel.putData( buffer );
    }

    if ( performance_listeners.hasListeners() ) {
      performance_listeners.dispatch().virtualChannelDataReceived( local_vmid,
        vmid, message.getChannelID(), bytes );
    }
  }

  private void handleChannelClose( ChannelCloseIMessage message, VMID vmid ) {
    channelClose( vmid, message.getChannelID(), false );
  }

  private PingResponseIMessage handlePing( PingIMessage message ) {
    return new PingResponseIMessage( message.getSequenceNumber() );
  }

  private void handlePingResponse( PingResponseIMessage message ) {
    ping_wait_map_lock.lock();
    try {
      ObjectSlot<PingResponseIMessage> slot =
        ping_wait_map.get( message.getSequenceNumber() );
      if ( slot != null ) slot.set( message );
    }
    finally {
      ping_wait_map_lock.unlock();
    }
  }


  private Serializable generateReconnectToken( SessionInfo session_info,
    TokenReconnectAuthenticationHandler token_handler, VMID vmid,
    UserContextInfo user_context, ConnectionArgs connection_args,
    SocketAddress remote_address, Object session_source,
    Serializable previous_reconnect_token, boolean send_token_change_message ) {


    Serializable reconnect_token = token_handler.generateReconnectToken( user_context,
      connection_args, remote_address, session_source, previous_reconnect_token );

    int regeneration_interval_sec =
      token_handler.getTokenRegenerationInterval();

    TokenRegenerator regenerator = new TokenRegenerator( session_info, token_handler,
      vmid, user_context, connection_args, remote_address, session_source,
      reconnect_token );

    ScheduledFuture<?> future =
      SharedThreadPool.INSTANCE.schedule( regenerator, regeneration_interval_sec,
      TimeUnit.SECONDS );
    session_info.setReconnectTokenRegenerationTimer( future );

    if ( send_token_change_message ) {
      for( int i = 2; i >= 0; i-- ) {
        try {
          spi.sendMessage( vmid,
            new SessionTokenChangeIMessage( reconnect_token ), null );
          break;
        }
        catch ( IOException e ) {
          String message = "Unable to send SessionTokenChange message";
          if ( i == 0 ) message += " (will NOT retry)";
          else message += " (will retry)";

          LOG.warn( message, e );
        }
      }
    }

    return reconnect_token;
  }


  private void checkArgsForSerialization( Object[] args ) {
    Class array_type = null;

    for( int i = 0; i < args.length; i++ ) {
      Object arg = args[ i ];
      if ( arg == null ) continue;

      if ( !( arg instanceof ForceProxy ) && arg instanceof Serializable ) continue;

      // Figure out the array type if we haven't already
      if ( array_type == null ) {
        array_type = args.getClass().getComponentType();
      }

      // Try to wrap the object in a proxy
      try {
        Proxy proxy = local_handler.createProxy( arg, null );
        if ( array_type != null && proxy != null &&
          array_type.isAssignableFrom( proxy.getClass() ) ) {
         
//          System.out.println( "Replacing argument " + i + " (" + args[ i ] +
//            ") with: " + proxy );
          args[ i ] = proxy;
        }
      }
      catch( IllegalProxyDelegateException ex ) {
        // skip the switch-out
      }
    }
  }


  Registry getRemoteRegistry( VMID vmid ) {
    if ( !spi.hasConnection( vmid ) ) {
      throw new NotConnectedException( vmid );
    }

    TObjectIntMap<MethodIDTemplate> method_map = MethodMap.generateReverseMethodMap(
      MethodMap.generateMethodMap( Registry.class ) );

    Object proxy = java.lang.reflect.Proxy.newProxyInstance(
      Registry.class.getClassLoader(), new Class[] { Proxy.class, Registry.class },
      new ProxyInvocationHandler( vmid, 0, method_map, null, null, local_vmid ) );
    return ( Registry ) proxy;
  }


  private class CallInterruptProcedure implements TIntProcedure {
    @Override
    public boolean execute( int call_id ) {
      CallInfoAndAckControl call_info = call_wait_map.get( call_id );
      call_info.return_slot.set( SESSION_CLOSED_FLAG );
      return true;
    }
  }


  private class ChannelCloseProcedure implements TObjectProcedure<VirtualByteChannel> {
    @Override
    public boolean execute( VirtualByteChannel channel ) {
      if ( channel != null ) channel.closedByPeer( true );
      return true;
    }
  }


  private static class InvokeCloseFlag extends InvokeReturnIMessage {
    public InvokeCloseFlag() {
      //noinspection ThrowableResultOfMethodCallIgnored
      super( -1, buildException(), true, null, null );
    }


    private static InterruptedCallException buildException() {
      InterruptedCallException ex = new InterruptedCallException(
        "Connection to peer closed during method invocation" );

      // Erase the stack because this exception is build ahead of time and reused.
      // So, the stack is pretty pointless (and confusing) when viewed.
      ex.setStackTrace( new StackTraceElement[ 0 ] );
      return ex;
    }
  }


  /**
   * Used to signal that an invocation did not received the expected acknowlegement from
   * the server.
   */
  private static class InvokeNotAckedFlag extends InvokeReturnIMessage {

    public InvokeNotAckedFlag( boolean received_any_acks ) {
      //noinspection ThrowableResultOfMethodCallIgnored
      super( -1, buildException( received_any_acks ), true, null, null );
    }


    private static MethodInvocationFailedException buildException(
      boolean received_any_acks ) {

      String message;
      if ( received_any_acks ) {
        message = "Message acknowledgement timeout exceeded";
      }
      else {
        message = "Initial message acknowledgement timeout exceeded";
      }
      MethodInvocationFailedException ex =
        new MethodInvocationFailedException( message );

      // Erase the stack because this exception is build ahead of time and reused.
      // So, the stack is pretty pointless (and confusing) when viewed.
      ex.setStackTrace( new StackTraceElement[ 0 ] );
      return ex;
    }
  }


  private class TokenRegenerator implements Runnable {
    private final SessionInfo session_info;
    private final TokenReconnectAuthenticationHandler auth_handler;
    private final VMID vmid;
    private final UserContextInfo user_context;
    private final ConnectionArgs connection_args;
    private final SocketAddress remote_address;
    private final Object session_source;
    private final Serializable previous_reconnect_token;

    TokenRegenerator( SessionInfo session_info,
      TokenReconnectAuthenticationHandler auth_handler, VMID vmid,
      UserContextInfo user_context, ConnectionArgs connection_args,
      SocketAddress remote_address, Object session_source,
      Serializable previous_reconnect_token ) {

      this.session_info = session_info;
      this.auth_handler = auth_handler;
      this.vmid = vmid;
      this.user_context = user_context;
      this.connection_args = connection_args;
      this.remote_address = remote_address;
      this.session_source = session_source;
      this.previous_reconnect_token = previous_reconnect_token;
    }


    @Override
    public void run() {
      Thread.currentThread().setName( "Intrepid Session TokenRegenerator: " +
        remote_address );

      try {
        generateReconnectToken( session_info, auth_handler, vmid, user_context,
          connection_args, remote_address, session_source,
          previous_reconnect_token, true );
      }
      catch( Throwable t ) {
        LOG.warn( "Unexpected error regenerating reconnection token", t );
      }
    }
  }


  /**
   * This object serves two purposes:
   * <ol>
   * <li>It is the data container for information relating to a call invocation</li>
   * <li>It is the Runnable that is executed if the call is not ack'ed by the server
   *     in a proper amount of time</li>
   * </ol>
   */
  private static class CallInfoAndAckControl implements Runnable {
    long ack_deadline_ms;       // WARNING: milliseconds

    final ScheduledExecutor executor;
    final ObjectSlot<InvokeReturnIMessage> return_slot;

    private volatile boolean received_an_ack = false;
    private volatile boolean aborted_by_ack_fail = false;

    ScheduledFuture<?> ack_expect_future;
    final Lock ack_handler_lock = new ReentrantLock();


    private CallInfoAndAckControl( @NotNull ScheduledExecutor executor,
      @NotNull ObjectSlot<InvokeReturnIMessage> return_slot ) {

      ValidationKit.checkNonnull( return_slot, "return_slot" );

      this.executor = executor;
      this.return_slot = return_slot;

      // Default, should be overridden
      setSessionAckRateSec( REQUEST_INVOKE_ACK_RATE_SEC );
    }


    void setSessionAckRateSec( byte ack_sec ) {
      // WARNING: Converting to milliseconds
      ack_deadline_ms = Math.round( TimeUnit.SECONDS.toMillis( ack_sec ) * 2.5 );
    }


    void scheduleAckExpector( boolean first ) {
      if ( !first ) received_an_ack = true;

      ack_handler_lock.lock();
      try {
        // Do nothing if we've already killed the call due to missing an ack
        if ( aborted_by_ack_fail ) return;

        if ( ack_expect_future != null ) {
          ack_expect_future.cancel( false );
        }

        // NOTE: time is in milliseconds
        ack_expect_future =
          executor.schedule( this, ack_deadline_ms, TimeUnit.MILLISECONDS );
      }
      finally {
        ack_handler_lock.unlock();
      }
    }


    /**
     * Called when a call has not been acknowledged within the deadline.
     */
    @Override
    public void run() {
      InvokeNotAckedFlag flag;
      if ( received_an_ack ) {
        flag = SUBSEQUENT_NOT_ACKED_FLAG;
      }
      else flag = NEVER_ACKED_FLAG;


      ack_handler_lock.lock();
      try {
        return_slot.compareAndSet( null, flag );
        aborted_by_ack_fail = true;
        ack_expect_future = null;
      }
      finally {
        ack_handler_lock.unlock();
      }
    }


    public void dispose() {
      final ScheduledFuture<?> future = ack_expect_future;
      if ( future != null ) future.cancel( false );
    }
  }
}
TOP

Related Classes of com.starlight.intrepid.RemoteCallHandler$InvokeNotAckedFlag

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.