/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
package org.apache.mina.filter;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;
import org.apache.mina.common.ByteBuffer;
import org.apache.mina.common.ByteBufferProxy;
import org.apache.mina.common.IoFilterAdapter;
import org.apache.mina.common.IoFilterChain;
import org.apache.mina.common.IoFuture;
import org.apache.mina.common.IoFutureListener;
import org.apache.mina.common.IoHandler;
import org.apache.mina.common.IoSession;
import org.apache.mina.common.WriteFuture;
import org.apache.mina.common.support.DefaultWriteFuture;
import org.apache.mina.filter.support.SSLHandler;
import org.apache.mina.util.SessionLog;
/**
* An SSL filter that encrypts and decrypts the data exchanged in the session.
* Adding this filter triggers SSL handshake procedure immediately by sending
* a SSL 'hello' message, so you don't need to call
* {@link #startSSL(IoSession)} manually unless you are implementing StartTLS
* (see below).
* <p>
* This filter uses an {@link SSLEngine} which was introduced in Java 5, so
* Java version 5 or above is mandatory to use this filter. And please note that
* this filter only works for TCP/IP connections.
* <p>
* This filter logs debug information using {@link SessionLog}.
*
* <h2>Implementing StartTLS</h2>
* <p>
* You can use {@link #DISABLE_ENCRYPTION_ONCE} attribute to implement StartTLS:
* <pre>
* public void messageReceived(IoSession session, Object message) {
* if (message instanceof MyStartTLSRequest) {
* // Insert SSLFilter to get ready for handshaking
* session.getFilterChain().addFirst(sslFilter);
*
* // Disable encryption temporarilly.
* // This attribute will be removed by SSLFilter
* // inside the Session.write() call below.
* session.setAttribute(SSLFilter.DISABLE_ENCRYPTION_ONCE, Boolean.TRUE);
*
* // Write StartTLSResponse which won't be encrypted.
* session.write(new MyStartTLSResponse(OK));
*
* // Now DISABLE_ENCRYPTION_ONCE attribute is cleared.
* assert session.getAttribute(SSLFilter.DISABLE_ENCRYPTION_ONCE) == null;
* }
* }
* </pre>
*
* @author The Apache Directory Project (mina-dev@directory.apache.org)
* @version $Rev: 587373 $, $Date: 2007-10-23 11:54:05 +0900 (화, 23 10월 2007) $
*/
public class SSLFilter extends IoFilterAdapter {
/**
* A session attribute key that stores underlying {@link SSLSession}
* for each session.
*/
public static final String SSL_SESSION = SSLFilter.class.getName()
+ ".SSLSession";
/**
* A session attribute key that makes next one write request bypass
* this filter (not encrypting the data). This is a marker attribute,
* which means that you can put whatever as its value. ({@link Boolean#TRUE}
* is preferred.) The attribute is automatically removed from the session
* attribute map as soon as {@link IoSession#write(Object)} is invoked,
* and therefore should be put again if you want to make more messages
* bypass this filter. This is especially useful when you implement
* StartTLS.
*/
public static final String DISABLE_ENCRYPTION_ONCE = SSLFilter.class
.getName()
+ ".DisableEncryptionOnce";
/**
* A session attribute key that makes this filter to emit a
* {@link IoHandler#messageReceived(IoSession, Object)} event with a
* special message ({@link #SESSION_SECURED} or {@link #SESSION_UNSECURED}).
* This is a marker attribute, which means that you can put whatever as its
* value. ({@link Boolean#TRUE} is preferred.) By default, this filter
* doesn't emit any events related with SSL session flow control.
*/
public static final String USE_NOTIFICATION = SSLFilter.class.getName()
+ ".UseNotification";
/**
* A special message object which is emitted with a {@link IoHandler#messageReceived(IoSession, Object)}
* event when the session is secured and its {@link #USE_NOTIFICATION}
* attribute is set.
*/
public static final SSLFilterMessage SESSION_SECURED = new SSLFilterMessage(
"SESSION_SECURED");
/**
* A special message object which is emitted with a {@link IoHandler#messageReceived(IoSession, Object)}
* event when the session is not secure anymore and its {@link #USE_NOTIFICATION}
* attribute is set.
*/
public static final SSLFilterMessage SESSION_UNSECURED = new SSLFilterMessage(
"SESSION_UNSECURED");
private static final String NEXT_FILTER = SSLFilter.class.getName()
+ ".NextFilter";
private static final String SSL_HANDLER = SSLFilter.class.getName()
+ ".SSLHandler";
// SSL Context
private SSLContext sslContext;
private boolean client;
private boolean needClientAuth;
private boolean wantClientAuth;
private String[] enabledCipherSuites;
private String[] enabledProtocols;
/**
* Creates a new SSL filter using the specified {@link SSLContext}.
*/
public SSLFilter(SSLContext sslContext) {
if (sslContext == null) {
throw new NullPointerException("sslContext");
}
this.sslContext = sslContext;
}
/**
* Returns the underlying {@link SSLSession} for the specified session.
*
* @return <tt>null</tt> if no {@link SSLSession} is initialized yet.
*/
public SSLSession getSSLSession(IoSession session) {
return (SSLSession) session.getAttribute(SSL_SESSION);
}
/**
* (Re)starts SSL session for the specified <tt>session</tt> if not started yet.
* Please note that SSL session is automatically started by default, and therefore
* you don't need to call this method unless you've used TLS closure.
*
* @return <tt>true</tt> if the SSL session has been started, <tt>false</tt> if already started.
* @throws SSLException if failed to start the SSL session
*/
public boolean startSSL(IoSession session) throws SSLException {
SSLHandler handler = getSSLSessionHandler(session);
boolean started;
synchronized (handler) {
if (handler.isOutboundDone()) {
NextFilter nextFilter = (NextFilter) session
.getAttribute(NEXT_FILTER);
handler.destroy();
handler.init();
handler.handshake(nextFilter);
started = true;
} else {
started = false;
}
}
handler.flushScheduledEvents();
return started;
}
/**
* Returns <tt>true</tt> if and only if the specified <tt>session</tt> is
* encrypted/decrypted over SSL/TLS currently. This method will start
* to retun <tt>false</tt> after TLS <tt>close_notify</tt> message
* is sent and any messages written after then is not goinf to get encrypted.
*/
public boolean isSSLStarted(IoSession session) {
SSLHandler handler = getSSLSessionHandler0(session);
if (handler == null) {
return false;
}
synchronized (handler) {
return !handler.isOutboundDone();
}
}
/**
* Stops the SSL session by sending TLS <tt>close_notify</tt> message to
* initiate TLS closure.
*
* @param session the {@link IoSession} to initiate TLS closure
* @throws SSLException if failed to initiate TLS closure
* @throws IllegalArgumentException if this filter is not managing the specified session
*/
public WriteFuture stopSSL(IoSession session) throws SSLException {
SSLHandler handler = getSSLSessionHandler(session);
NextFilter nextFilter = (NextFilter) session.getAttribute(NEXT_FILTER);
WriteFuture future;
synchronized (handler) {
future = initiateClosure(nextFilter, session);
}
handler.flushScheduledEvents();
return future;
}
/**
* Returns <tt>true</tt> if the engine is set to use client mode
* when handshaking.
*/
public boolean isUseClientMode() {
return client;
}
/**
* Configures the engine to use client (or server) mode when handshaking.
*/
public void setUseClientMode(boolean clientMode) {
this.client = clientMode;
}
/**
* Returns <tt>true</tt> if the engine will <em>require</em> client authentication.
* This option is only useful to engines in the server mode.
*/
public boolean isNeedClientAuth() {
return needClientAuth;
}
/**
* Configures the engine to <em>require</em> client authentication.
* This option is only useful for engines in the server mode.
*/
public void setNeedClientAuth(boolean needClientAuth) {
this.needClientAuth = needClientAuth;
}
/**
* Returns <tt>true</tt> if the engine will <em>request</em> client authentication.
* This option is only useful to engines in the server mode.
*/
public boolean isWantClientAuth() {
return wantClientAuth;
}
/**
* Configures the engine to <em>request</em> client authentication.
* This option is only useful for engines in the server mode.
*/
public void setWantClientAuth(boolean wantClientAuth) {
this.wantClientAuth = wantClientAuth;
}
/**
* Returns the list of cipher suites to be enabled when {@link SSLEngine}
* is initialized.
*
* @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
*/
public String[] getEnabledCipherSuites() {
return enabledCipherSuites;
}
/**
* Sets the list of cipher suites to be enabled when {@link SSLEngine}
* is initialized.
*
* @param cipherSuites <tt>null</tt> means 'use {@link SSLEngine}'s default.'
*/
public void setEnabledCipherSuites(String[] cipherSuites) {
this.enabledCipherSuites = cipherSuites;
}
/**
* Returns the list of protocols to be enabled when {@link SSLEngine}
* is initialized.
*
* @return <tt>null</tt> means 'use {@link SSLEngine}'s default.'
*/
public String[] getEnabledProtocols() {
return enabledProtocols;
}
/**
* Sets the list of protocols to be enabled when {@link SSLEngine}
* is initialized.
*
* @param protocols <tt>null</tt> means 'use {@link SSLEngine}'s default.'
*/
public void setEnabledProtocols(String[] protocols) {
this.enabledProtocols = protocols;
}
public void onPreAdd(IoFilterChain parent, String name,
NextFilter nextFilter) throws SSLException {
if (parent.contains(SSLFilter.class)) {
throw new IllegalStateException(
"A filter chain cannot contain more than one SSLFilter.");
}
IoSession session = parent.getSession();
session.setAttribute(NEXT_FILTER, nextFilter);
// Create an SSL handler and start handshake.
SSLHandler handler = new SSLHandler(this, sslContext, session);
session.setAttribute(SSL_HANDLER, handler);
}
public void onPostAdd(IoFilterChain parent, String name,
NextFilter nextFilter) throws SSLException {
SSLHandler handler = getSSLSessionHandler(parent.getSession());
synchronized (handler) {
handler.handshake(nextFilter);
}
handler.flushScheduledEvents();
}
public void onPreRemove(IoFilterChain parent, String name,
NextFilter nextFilter) throws SSLException {
IoSession session = parent.getSession();
stopSSL(session);
session.removeAttribute(NEXT_FILTER);
session.removeAttribute(SSL_HANDLER);
}
// IoFilter impl.
public void sessionClosed(NextFilter nextFilter, IoSession session)
throws SSLException {
SSLHandler handler = getSSLSessionHandler(session);
try {
synchronized (handler) {
if (isSSLStarted(session)) {
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " Closed: "
+ getSSLSessionHandler(session));
}
}
// release resources
handler.destroy();
}
handler.flushScheduledEvents();
} finally {
// notify closed session
nextFilter.sessionClosed(session);
}
}
public void messageReceived(NextFilter nextFilter, IoSession session,
Object message) throws SSLException {
SSLHandler handler = getSSLSessionHandler(session);
synchronized (handler) {
if (!isSSLStarted(session) && handler.isInboundDone()) {
handler.scheduleMessageReceived(nextFilter, message);
} else {
ByteBuffer buf = (ByteBuffer) message;
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " Data Read: " + handler + " ("
+ buf + ')');
}
try {
// forward read encrypted data to SSL handler
handler.messageReceived(nextFilter, buf.buf());
// Handle data to be forwarded to application or written to net
handleSSLData(nextFilter, handler);
if (handler.isInboundDone()) {
if (handler.isOutboundDone()) {
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session,
" SSL Session closed.");
}
handler.destroy();
} else {
initiateClosure(nextFilter, session);
}
if (buf.hasRemaining()) {
// Forward the data received after closure.
handler.scheduleMessageReceived(nextFilter, buf);
}
}
} catch (SSLException ssle) {
if (!handler.isHandshakeComplete()) {
SSLException newSSLE = new SSLHandshakeException(
"SSL handshake failed.");
newSSLE.initCause(ssle);
ssle = newSSLE;
}
throw ssle;
}
}
}
handler.flushScheduledEvents();
}
public void messageSent(NextFilter nextFilter, IoSession session,
Object message) {
if (message instanceof EncryptedBuffer) {
EncryptedBuffer buf = (EncryptedBuffer) message;
buf.release();
nextFilter.messageSent(session, buf.originalBuffer);
} else {
// ignore extra buffers used for handshaking
}
}
public void filterWrite(NextFilter nextFilter, IoSession session,
WriteRequest writeRequest) throws SSLException {
boolean needsFlush = true;
SSLHandler handler = getSSLSessionHandler(session);
synchronized (handler) {
if (!isSSLStarted(session)) {
handler.scheduleFilterWrite(nextFilter,
writeRequest);
}
// Don't encrypt the data if encryption is disabled.
else if (session.containsAttribute(DISABLE_ENCRYPTION_ONCE)) {
// Remove the marker attribute because it is temporary.
session.removeAttribute(DISABLE_ENCRYPTION_ONCE);
handler.scheduleFilterWrite(nextFilter,
writeRequest);
} else {
// Otherwise, encrypt the buffer.
ByteBuffer buf = (ByteBuffer) writeRequest.getMessage();
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " Filtered Write: " + handler);
}
if (handler.isWritingEncryptedData()) {
// data already encrypted; simply return buffer
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " already encrypted: "
+ buf);
}
handler.scheduleFilterWrite(nextFilter,
writeRequest);
} else if (handler.isHandshakeComplete()) {
// SSL encrypt
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " encrypt: " + buf);
}
int pos = buf.position();
handler.encrypt(buf.buf());
buf.position(pos);
ByteBuffer encryptedBuffer = new EncryptedBuffer(SSLHandler
.copy(handler.getOutNetBuffer()), buf);
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " encrypted buf: "
+ encryptedBuffer);
}
handler.scheduleFilterWrite(nextFilter,
new WriteRequest(encryptedBuffer, writeRequest
.getFuture()));
} else {
if (!session.isConnected()) {
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session,
" Write request on closed session.");
}
} else {
if (SessionLog.isDebugEnabled(session)) {
SessionLog
.debug(session,
" Handshaking is not complete yet. Buffering write request.");
}
handler.schedulePreHandshakeWriteRequest(nextFilter,
writeRequest);
}
needsFlush = false;
}
}
}
if (needsFlush) {
handler.flushScheduledEvents();
}
}
public void filterClose(final NextFilter nextFilter, final IoSession session)
throws SSLException {
SSLHandler handler = getSSLSessionHandler0(session);
if (handler == null) {
// The connection might already have closed, or
// SSL might have not started yet.
nextFilter.filterClose(session);
return;
}
WriteFuture future = null;
try {
synchronized (handler) {
if (isSSLStarted(session)) {
future = initiateClosure(nextFilter, session);
future.addListener(new IoFutureListener() {
public void operationComplete(IoFuture future) {
nextFilter.filterClose(session);
}
});
}
}
handler.flushScheduledEvents();
} finally {
if (future == null) {
nextFilter.filterClose(session);
}
}
}
private WriteFuture initiateClosure(NextFilter nextFilter, IoSession session)
throws SSLException {
SSLHandler handler = getSSLSessionHandler(session);
// if already shut down
if (!handler.closeOutbound()) {
return DefaultWriteFuture.newNotWrittenFuture(session);
}
// there might be data to write out here?
WriteFuture future = handler.writeNetBuffer(nextFilter);
if (handler.isInboundDone()) {
handler.destroy();
}
if (session.containsAttribute(USE_NOTIFICATION)) {
handler.scheduleMessageReceived(nextFilter, SESSION_UNSECURED);
}
return future;
}
// Utiliities
private void handleSSLData(NextFilter nextFilter, SSLHandler handler)
throws SSLException {
// Flush any buffered write requests occurred before handshaking.
if (handler.isHandshakeComplete()) {
handler.flushPreHandshakeEvents();
}
// Write encrypted data to be written (if any)
handler.writeNetBuffer(nextFilter);
// handle app. data read (if any)
handleAppDataRead(nextFilter, handler);
}
private void handleAppDataRead(NextFilter nextFilter, SSLHandler handler) {
IoSession session = handler.getSession();
handler.getAppBuffer().flip();
if (!handler.getAppBuffer().hasRemaining()) {
handler.getAppBuffer().clear();
return;
}
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " appBuffer: " + handler.getAppBuffer());
}
// forward read app data
ByteBuffer readBuffer = SSLHandler.copy(handler.getAppBuffer());
handler.getAppBuffer().clear();
if (SessionLog.isDebugEnabled(session)) {
SessionLog.debug(session, " app data read: " + readBuffer + " ("
+ readBuffer.getHexDump() + ')');
}
handler.scheduleMessageReceived(nextFilter, readBuffer);
}
private SSLHandler getSSLSessionHandler(IoSession session) {
SSLHandler handler = getSSLSessionHandler0(session);
if (handler == null) {
throw new IllegalStateException();
}
if (handler.getParent() != this) {
throw new IllegalArgumentException("Not managed by this filter.");
}
return handler;
}
private SSLHandler getSSLSessionHandler0(IoSession session) {
return (SSLHandler) session.getAttribute(SSL_HANDLER);
}
/**
* A message that is sent from {@link SSLFilter} when the connection became
* secure or is not secure anymore.
*
* @author The Apache Directory Project (mina-dev@directory.apache.org)
* @version $Rev: 587373 $, $Date: 2007-10-23 11:54:05 +0900 (화, 23 10월 2007) $
*/
public static class SSLFilterMessage {
private final String name;
private SSLFilterMessage(String name) {
this.name = name;
}
public String toString() {
return name;
}
}
private static class EncryptedBuffer extends ByteBufferProxy {
private final ByteBuffer originalBuffer;
private EncryptedBuffer(ByteBuffer buf, ByteBuffer originalBuffer) {
super(buf);
this.originalBuffer = originalBuffer;
}
}
}