package com.subgraph.orchid.sockets.sslengine;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.SocketException;
import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
/**
 * Drives an {@link SSLEngine} over a pair of blocking streams.
 *
 * <p>Plaintext is exchanged with the caller through two buffers: data to be
 * sent is placed in {@link #getSendBuffer()} and then pushed out with
 * {@link #write()}; decrypted data is deposited in {@link #getRecvBuffer()}
 * by {@link #read()}. All four internal buffers are array-backed and are kept
 * in "fill" state between calls (position marks the end of the pending data);
 * each operation flips a buffer before draining it and compacts it afterwards.
 *
 * <p>This class performs no synchronization of its own.
 */
public class SSLEngineManager {
    private final static Logger logger = Logger.getLogger(SSLEngineManager.class.getName());

    private final SSLEngine engine;
    private final InputStream input;
    private final OutputStream output;

    /** Decrypted data received from the peer. */
    private final ByteBuffer peerApplicationBuffer;
    /** Encrypted data read from {@code input} and not yet consumed by the engine. */
    private final ByteBuffer peerNetworkBuffer;
    /** Plaintext queued by the caller for sending. */
    private final ByteBuffer myApplicationBuffer;
    /** Encrypted data produced by the engine and not yet written to {@code output}. */
    private final ByteBuffer myNetworkBuffer;

    private final HandshakeCallbackHandler handshakeCallback;
    private boolean handshakeStarted = false;

    /**
     * Creates a manager and allocates its buffers using the sizes reported by
     * the engine's current session.
     *
     * @param engine            engine performing the TLS processing
     * @param handshakeCallback notified each time a handshake finishes; may be {@code null}
     * @param input             stream encrypted data is read from
     * @param output            stream encrypted data is written to
     */
    SSLEngineManager(SSLEngine engine, HandshakeCallbackHandler handshakeCallback, InputStream input, OutputStream output) {
        this.engine = engine;
        this.handshakeCallback = handshakeCallback;
        this.input = input;
        this.output = output;
        final SSLSession session = engine.getSession();
        this.peerApplicationBuffer = createApplicationBuffer(session);
        this.peerNetworkBuffer = createPacketBuffer(session);
        this.myApplicationBuffer = createApplicationBuffer(session);
        this.myNetworkBuffer = createPacketBuffer(session);
    }

    /** Allocates a buffer of {@code session.getApplicationBufferSize()} bytes. */
    private static ByteBuffer createApplicationBuffer(SSLSession session) {
        return createBuffer(session.getApplicationBufferSize());
    }

    /** Allocates a buffer of {@code session.getPacketBufferSize()} bytes. */
    private static ByteBuffer createPacketBuffer(SSLSession session) {
        return createBuffer(session.getPacketBufferSize());
    }

    /**
     * Allocates an array-backed buffer. The network helpers below access the
     * backing array directly, so the buffers must not be direct buffers.
     */
    private static ByteBuffer createBuffer(int sz) {
        final byte[] array = new byte[sz];
        return ByteBuffer.wrap(array);
    }

    /**
     * Begins a handshake on the engine and runs it for as long as handshake
     * steps can be completed.
     *
     * @throws IOException if the engine or the underlying streams fail
     */
    void startHandshake() throws IOException {
        logger.fine("startHandshake()");
        handshakeStarted = true;
        engine.beginHandshake();
        runHandshake();
    }

    /** Returns the buffer the caller fills with plaintext before calling {@link #write()}. */
    ByteBuffer getSendBuffer() {
        return myApplicationBuffer;
    }

    /** Returns the buffer into which {@link #read()} deposits decrypted data. */
    ByteBuffer getRecvBuffer() {
        return peerApplicationBuffer;
    }

    /**
     * Encrypts the pending contents of the send buffer and writes the result
     * to the output stream, starting the handshake first if that has not
     * happened yet.
     *
     * @return the number of bytes removed from the send buffer; 0 if it was empty
     * @throws BufferOverflowException  if the engine reports {@code BUFFER_OVERFLOW}
     * @throws BufferUnderflowException if the engine reports {@code BUFFER_UNDERFLOW}
     * @throws SSLException             if the engine reports {@code CLOSED}
     * @throws IOException              if writing to the output stream fails
     */
    int write() throws IOException {
        logger.fine("write()");
        if(!handshakeStarted) {
            startHandshake();
        }
        final int p = myApplicationBuffer.position();
        if(p == 0) {
            return 0;
        }
        myNetworkBuffer.clear();
        myApplicationBuffer.flip();
        final SSLEngineResult result = engine.wrap(myApplicationBuffer, myNetworkBuffer);
        myApplicationBuffer.compact();
        if(logger.isLoggable(Level.FINE)) {
            logResult(result);
        }
        switch(result.getStatus()) {
        case BUFFER_OVERFLOW:
            throw new BufferOverflowException();
        case BUFFER_UNDERFLOW:
            throw new BufferUnderflowException();
        case CLOSED:
            throw new SSLException("SSLEngine is closed");
        case OK:
            break;
        default:
            break;
        }
        flush();
        // If the wrap triggered handshake activity, try again so that data
        // the engine could not accept during the handshake still gets sent.
        if(runHandshake()) {
            write();
        }
        return p - myApplicationBuffer.position();
    }

    /**
     * Reads and decrypts data from the peer into the receive buffer.
     *
     * <p>The call blocks until at least one byte of application data has been
     * added to the receive buffer or the inbound side has ended; it never
     * returns 0. Encrypted data already buffered from an earlier read is
     * processed before the input stream is touched again, so a record that
     * has already arrived is never held back behind a blocking read.
     *
     * <p>End of stream on the input is reported as -1 without calling
     * {@code closeInbound()} on the engine.
     *
     * @return the number of bytes added to the receive buffer, or -1 if the
     *         inbound side is done or the input stream reached end of stream
     * @throws BufferOverflowException if the receive buffer cannot hold the decrypted record
     * @throws SSLException            if the engine fails, or cannot make progress
     * @throws IOException             if the underlying streams fail
     */
    int read() throws IOException {
        logger.fine("read()");
        if(!handshakeStarted) {
            startHandshake();
        }
        final int start = peerApplicationBuffer.position();
        while(true) {
            if(engine.isInboundDone()) {
                return -1;
            }
            // Only block on the network when nothing is buffered.
            if(peerNetworkBuffer.position() == 0 && networkReadBuffer(peerNetworkBuffer) == -1) {
                return -1;
            }
            peerNetworkBuffer.flip();
            final SSLEngineResult result = engine.unwrap(peerNetworkBuffer, peerApplicationBuffer);
            peerNetworkBuffer.compact();
            if(logger.isLoggable(Level.FINE)) {
                logResult(result);
            }
            boolean inboundClosed = false;
            switch(result.getStatus()) {
            case BUFFER_OVERFLOW:
                throw new BufferOverflowException();
            case BUFFER_UNDERFLOW:
                // The buffered bytes do not form a complete record yet.
                if(!peerNetworkBuffer.hasRemaining()) {
                    // A zero-length read would return 0 forever; fail instead of spinning.
                    throw new SSLException("Incomplete TLS record does not fit in the network buffer");
                }
                if(networkReadBuffer(peerNetworkBuffer) == -1) {
                    return -1;
                }
                continue;
            case CLOSED:
                input.close();
                inboundClosed = true;
                break;
            case OK:
                break;
            default:
                break;
            }
            final boolean handshakeRan = runHandshake();
            if(engine.isInboundDone()) {
                return -1;
            }
            final int produced = peerApplicationBuffer.position() - start;
            if(produced > 0) {
                return produced;
            }
            if(inboundClosed) {
                // The input stream has been closed above; nothing more can arrive.
                return -1;
            }
            if(result.bytesConsumed() == 0 && !handshakeRan) {
                // Neither the unwrap nor the handshake advanced; looping again cannot help.
                throw new SSLException("SSLEngine made no progress while reading");
            }
            // Only handshake or alert data was processed; keep going until
            // application data shows up.
        }
    }

    /**
     * Flushes pending encrypted output, closes the outbound side of the
     * engine (or, if that is already done, the inbound side), runs the
     * resulting handshake steps and finally closes the output stream. The
     * output stream is closed even if one of the earlier steps throws.
     *
     * @throws IOException if the engine or the underlying streams fail
     */
    void close() throws IOException {
        try {
            flush();
            if(!engine.isOutboundDone()) {
                engine.closeOutbound();
                runHandshake();
            } else if(!engine.isInboundDone()) {
                engine.closeInbound();
                runHandshake();
            }
        } finally {
            output.close();
        }
    }

    /**
     * Writes all pending encrypted bytes to the output stream and flushes it.
     *
     * @throws IOException if writing to the output stream fails
     */
    void flush() throws IOException {
        myNetworkBuffer.flip();
        networkWriteBuffer(myNetworkBuffer);
        myNetworkBuffer.compact();
    }

    /**
     * Repeats {@link #processHandshake()} until it reports that no further
     * step could be performed.
     *
     * @return {@code true} if at least one handshake step was performed
     */
    private boolean runHandshake() throws IOException {
        boolean handshakeRan = false;
        while(true) {
            if(!processHandshake()) {
                return handshakeRan;
            } else {
                handshakeRan = true;
            }
        }
    }

    /**
     * Performs the single step requested by the engine's current handshake
     * status. Delegated tasks are run inline and the status is then
     * re-evaluated.
     *
     * @return {@code true} if a wrap or unwrap step succeeded and the caller
     *         should check for further steps; {@code false} otherwise
     */
    private boolean processHandshake() throws IOException {
        final HandshakeStatus hs = engine.getHandshakeStatus();
        logger.fine("processHandshake() hs = "+ hs);
        switch(hs) {
        case NEED_TASK:
            synchronousRunDelegatedTasks();
            return processHandshake();
        case NEED_UNWRAP:
            return handshakeUnwrap();
        case NEED_WRAP:
            return handshakeWrap();
        default:
            return false;
        }
    }

    /** Runs every delegated task the engine currently offers on the calling thread. */
    private void synchronousRunDelegatedTasks() {
        logger.fine("runDelegatedTasks()");
        while(true) {
            Runnable r = engine.getDelegatedTask();
            if(r == null) {
                return;
            }
            logger.fine("Running a task: "+ r);
            r.run();
        }
    }

    /**
     * Performs one unwrap step of the handshake, reading from the input
     * stream first if no encrypted data is buffered, and again if the engine
     * reports that the buffered record is incomplete.
     *
     * @return {@code true} if the handshake should continue; {@code false} on
     *         end of stream, when the engine reports {@code CLOSED}, or on
     *         any other non-OK status
     */
    private boolean handshakeUnwrap() throws IOException {
        logger.fine("handshakeUnwrap()");
        if(!engine.isInboundDone() && peerNetworkBuffer.position() == 0) {
            if(networkReadBuffer(peerNetworkBuffer) < 0) {
                return false;
            }
        }
        peerNetworkBuffer.flip();
        final SSLEngineResult result = engine.unwrap(peerNetworkBuffer, peerApplicationBuffer);
        peerNetworkBuffer.compact();
        if(logger.isLoggable(Level.FINE)) {
            logResult(result);
        }
        if(result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
            handshakeFinished();
        }
        switch(result.getStatus()) {
        case CLOSED:
            if(engine.isOutboundDone()) {
                output.close();
            }
            return false;
        case OK:
            return true;
        case BUFFER_UNDERFLOW:
            // Incomplete record: fetch more bytes and let the caller retry.
            if(networkReadBuffer(peerNetworkBuffer) < 0) {
                return false;
            }
            return true;
        default:
            return false;
        }
    }

    /**
     * Performs one wrap step of the handshake and flushes the produced bytes
     * to the output stream. When the engine reports {@code CLOSED} the flush
     * is best-effort: a {@link SocketException} is logged and otherwise ignored.
     *
     * @return {@code true} if the wrap succeeded and the handshake should
     *         continue; {@code false} otherwise
     */
    private boolean handshakeWrap() throws IOException {
        logger.fine("handshakeWrap()");
        myApplicationBuffer.flip();
        final SSLEngineResult result = engine.wrap(myApplicationBuffer, myNetworkBuffer);
        myApplicationBuffer.compact();
        if(logger.isLoggable(Level.FINE)) {
            logResult(result);
        }
        if(result.getHandshakeStatus() == HandshakeStatus.FINISHED) {
            handshakeFinished();
        }
        if(result.getStatus() == Status.CLOSED) {
            try {
                flush();
            } catch (SocketException e) {
                // The connection is being torn down anyway; record the
                // failure through the logger rather than on stderr.
                logger.log(Level.WARNING, "Failed to flush output after SSLEngine closed", e);
            }
        } else {
            flush();
        }
        switch(result.getStatus()) {
        case CLOSED:
            if(engine.isOutboundDone()) {
                output.close();
            }
            return false;
        case OK:
            return true;
        default:
            return false;
        }
    }

    /** Logs the status, handshake status and byte counts of an engine result at FINE level. */
    private void logResult(SSLEngineResult result) {
        logger.fine("Result status="+result.getStatus() + " hss="+ result.getHandshakeStatus() + " consumed = "+ result.bytesConsumed() + " produced = "+ result.bytesProduced());
    }

    /** Notifies the handshake callback, if one was supplied. */
    private void handshakeFinished() {
        if(handshakeCallback != null) {
            handshakeCallback.handshakeCompleted();
        }
    }

    /**
     * Writes the bytes between the buffer's position and limit to the output
     * stream, flushes the stream and marks the buffer as fully drained.
     *
     * @param buffer array-backed buffer, flipped for reading
     */
    private void networkWriteBuffer(ByteBuffer buffer) throws IOException {
        final byte[] bs = buffer.array();
        final int off = buffer.position();
        final int len = buffer.limit() - off;
        logger.fine("networkWriteBuffer(b, "+ off + ", "+ len +")");
        output.write(bs, off, len);
        output.flush();
        buffer.position(buffer.limit());
    }

    /**
     * Reads from the input stream into the free space of the buffer (between
     * position and limit) and advances the position by the number of bytes read.
     *
     * @param buffer array-backed buffer in fill state
     * @return the number of bytes read, or -1 on end of stream
     */
    private int networkReadBuffer(ByteBuffer buffer) throws IOException {
        final byte[] bs = buffer.array();
        final int off = buffer.position();
        final int len = buffer.limit() - off;
        final int n = input.read(bs, off, len);
        if(n != -1) {
            buffer.position(off + n);
        }
        logger.fine("networkReadBuffer(b, "+ off +", "+ len +") = "+ n);
        return n;
    }
}