Package com.subgraph.orchid.sockets.sslengine

Source Code of com.subgraph.orchid.sockets.sslengine.SSLEngineManager

package com.subgraph.orchid.sockets.sslengine;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.SocketException;
import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;

public class SSLEngineManager {
  private final static Logger logger = Logger.getLogger(SSLEngineManager.class.getName());
 
  private final SSLEngine engine;
  private final InputStream input;
  private final OutputStream output;
 
  private final ByteBuffer peerApplicationBuffer;
  private final ByteBuffer peerNetworkBuffer;
  private final ByteBuffer myApplicationBuffer;
  private final ByteBuffer myNetworkBuffer;
 
  private final HandshakeCallbackHandler handshakeCallback;
 
  private boolean handshakeStarted = false;
 
 
  SSLEngineManager(SSLEngine engine, HandshakeCallbackHandler handshakeCallback, InputStream input, OutputStream output) {
    this.engine = engine;
    this.handshakeCallback = handshakeCallback;
    this.input = input;
    this.output = output;
    final SSLSession session = engine.getSession();
    this.peerApplicationBuffer = createApplicationBuffer(session);
    this.peerNetworkBuffer = createPacketBuffer(session);
    this.myApplicationBuffer = createApplicationBuffer(session);
    this.myNetworkBuffer = createPacketBuffer(session);
  }
 
  private static ByteBuffer createApplicationBuffer(SSLSession session) {
    return createBuffer(session.getApplicationBufferSize());
  }
 
  private static ByteBuffer createPacketBuffer(SSLSession session) {
    return createBuffer(session.getPacketBufferSize());
  }
 
  private static ByteBuffer createBuffer(int sz) {
    final byte[] array = new byte[sz];
    return ByteBuffer.wrap(array);
  }
 
  void startHandshake() throws IOException {
    logger.fine("startHandshake()");
    handshakeStarted = true;
    engine.beginHandshake();
    runHandshake();
  }

  ByteBuffer getSendBuffer() {
    return myApplicationBuffer;
  }
 
  ByteBuffer getRecvBuffer() {
    return peerApplicationBuffer;
  }

 
  int write() throws IOException {
    logger.fine("write()");
    if(!handshakeStarted) {
      startHandshake();
    }
    final int p = myApplicationBuffer.position();
    if(p == 0) {
      return 0;
    }
    myNetworkBuffer.clear();
    myApplicationBuffer.flip();
    final SSLEngineResult result = engine.wrap(myApplicationBuffer, myNetworkBuffer);
    myApplicationBuffer.compact();
    if(logger.isLoggable(Level.FINE)) {
      logResult(result);
    }
   
    switch(result.getStatus()) {
    case BUFFER_OVERFLOW:
      throw new BufferOverflowException();
    case BUFFER_UNDERFLOW:
      throw new BufferUnderflowException();
    case CLOSED:
      throw new SSLException("SSLEngine is closed");

    case OK:
      break;
    default:
      break;
    }
   
    flush();
    if(runHandshake()) {
      write();
    }
   
    return p - myApplicationBuffer.position();

  }

  // either return -1 or peerApplicationBuffer has data to read
  int read() throws IOException {
    logger.fine("read()");
    if(!handshakeStarted) {
      startHandshake();
    }
   
    if(engine.isInboundDone()) {
      return -1;
    }
   
    final int n = networkReadBuffer(peerNetworkBuffer);
    if(n == -1) {
      return -1;
    }
    final int p = peerApplicationBuffer.position();
   
    peerNetworkBuffer.flip();
    final SSLEngineResult result = engine.unwrap(peerNetworkBuffer, peerApplicationBuffer);
    peerNetworkBuffer.compact();
    if(logger.isLoggable(Level.FINE)) {
      logResult(result);
    }
   
    switch(result.getStatus()) {
    case BUFFER_OVERFLOW:
      throw new BufferOverflowException();
     
    case BUFFER_UNDERFLOW:
      return 0; // <-- illegal return according to invariant
     
    case CLOSED:
      input.close();
      break;
    case OK:
      break;
    default:
      break;
    }

    runHandshake();
   
    if(n == -1) { // <-- can't happen
      engine.closeInbound();
    }
    if(engine.isInboundDone()) {
      return -1;
    }
    return peerApplicationBuffer.position() - p;
  }
 
  void close() throws IOException {
    try {
      flush();
      if(!engine.isOutboundDone()) {
        engine.closeOutbound();
        runHandshake();
      } else if(!engine.isInboundDone()) {
        engine.closeInbound();
        runHandshake();
      }
    } finally {
      output.close();
    }
  }
 
  void flush() throws IOException {
    myNetworkBuffer.flip();
    networkWriteBuffer(myNetworkBuffer);
    myNetworkBuffer.compact();
  }

 
  private boolean runHandshake() throws IOException {
    boolean handshakeRan = false;
    while(true) {
      if(!processHandshake()) {
        return handshakeRan;
      } else {
        handshakeRan = true;
      }
    }
  }
 
  private boolean processHandshake() throws IOException {
    final HandshakeStatus hs = engine.getHandshakeStatus();
    logger.fine("processHandshake() hs = "+ hs);
    switch(hs) {
    case NEED_TASK:
      synchronousRunDelegatedTasks();
      return processHandshake();

    case NEED_UNWRAP:
      return handshakeUnwrap();
     
    case NEED_WRAP:
      return handshakeWrap();

    default:
      return false;
    }
  }
 
  private void synchronousRunDelegatedTasks() {
    logger.fine("runDelegatedTasks()");
    while(true) {
      Runnable r = engine.getDelegatedTask();
      if(r == null) {
        return;
      }
      logger.fine("Running a task: "+ r);
      r.run();
    }
  }
 
  private boolean handshakeUnwrap() throws IOException {
    logger.fine("handshakeUnwrap()");
   
    if(!engine.isInboundDone() && peerNetworkBuffer.position() == 0) {
      if(networkReadBuffer(peerNetworkBuffer) < 0) {
        return false;
      }
    }
    peerNetworkBuffer.flip();
    final SSLEngineResult result = engine.unwrap(peerNetworkBuffer, peerApplicationBuffer);
    peerNetworkBuffer.compact();
   
    if(logger.isLoggable(Level.FINE)) {
      logResult(result);
    }

    if(result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
      handshakeFinished();
    }
    switch(result.getStatus()) {

    case CLOSED:
      if(engine.isOutboundDone()) {
        output.close();
      }
      return false;
    case OK:
      return true;
    case BUFFER_UNDERFLOW:
      if(networkReadBuffer(peerNetworkBuffer) < 0) {
        return false;
      }
      return true;
    default:
      return false;
    }
  }
 
  private boolean handshakeWrap() throws IOException {
    logger.fine("handshakeWrap()");
    myApplicationBuffer.flip();
    final SSLEngineResult result = engine.wrap(myApplicationBuffer, myNetworkBuffer);
    myApplicationBuffer.compact();
    if(logger.isLoggable(Level.FINE)) {
      logResult(result);
    }

    if(result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
      handshakeFinished();
    }
   
    if(result.getStatus() == Status.CLOSED) {
      try {
        flush();
      } catch (SocketException e) {
        e.printStackTrace();
      }
    } else {
      flush();
    }
   
    switch(result.getStatus()) {
    case CLOSED:
      if(engine.isOutboundDone()) {
        output.close();
      }
      return false;

    case OK:
      return true;

    default:
      return false;
   
    }
  }

  private void logResult(SSLEngineResult result) {
    logger.fine("Result status="+result.getStatus() + " hss="+ result.getHandshakeStatus() + " consumed = "+ result.bytesConsumed() + " produced = "+ result.bytesProduced());
  }
 
  private void handshakeFinished() {
    if(handshakeCallback != null) {
      handshakeCallback.handshakeCompleted();
    }
  }
 
  private void networkWriteBuffer(ByteBuffer buffer) throws IOException {
    final byte[] bs = buffer.array();
    final int off = buffer.position();
    final int len = buffer.limit() - off;
    logger.fine("networkWriteBuffer(b, "+ off + ", "+ len +")");
    output.write(bs, off, len);
    output.flush();
    buffer.position(buffer.limit());
  }
 
  private int networkReadBuffer(ByteBuffer buffer) throws IOException {
    final byte[] bs = buffer.array();
    final int off = buffer.position();
    final int len = buffer.limit() - off;

    final int n = input.read(bs, off, len);
    if(n != -1) {
      buffer.position(off + n);
    }
    logger.fine("networkReadBuffer(b, "+ off +", "+ len +") = "+ n);
    return n;
  }
 
}
TOP

Related Classes of com.subgraph.orchid.sockets.sslengine.SSLEngineManager

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.