package org.mockserver.proxy.http.direct;
import com.google.common.base.Charsets;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import org.mockserver.logging.LoggingHandler;
import org.mockserver.proxy.http.relay.BasicHttpDecoder;
import org.mockserver.proxy.http.relay.ProxyRelayHandler;
import org.mockserver.proxy.interceptor.Interceptor;
import org.mockserver.proxy.interceptor.ResponseInterceptor;
import org.mockserver.socket.SSLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLEngine;
import java.net.InetSocketAddress;
/**
 * Client-facing half of the direct HTTP proxy.
 *
 * <p>When the inbound channel becomes active this handler opens a second (outbound) connection to
 * {@code remoteSocketAddress} and relays the bytes it receives to it. While {@code bufferedMode} is
 * on, inbound chunks are accumulated in {@code channelBuffer} and handed to the {@link Interceptor}
 * before being written upstream; once the buffer has been flushed the handler switches to
 * pass-through mode and forwards every chunk as it arrives.
 *
 * <p>Messages that are not {@link ByteBuf} instances are ignored by {@link #channelRead}.
 */
public class DirectProxyUpstreamHandler extends ChannelDuplexHandler {

    private final Logger logger;
    private final InetSocketAddress remoteSocketAddress;
    private final boolean secure;
    private final int bufferedCapacity;
    private final Interceptor interceptor;
    private volatile Channel outboundChannel;
    private volatile ByteBuf channelBuffer;
    private volatile boolean bufferedMode;
    private volatile boolean flushedBuffer;
    private volatile Integer contentLength;
    private volatile int contentSoFar;
    private volatile boolean flushContent;

    /**
     * @param remoteSocketAddress address of the server the outbound connection is opened to
     * @param secure              when {@code true} an {@link SslHandler} in client mode is added to the outbound pipeline
     * @param bufferedCapacity    initial size of the accumulation buffer; buffering is enabled only when this is &gt; 0
     * @param interceptor         applied to the buffered bytes before they are written upstream
     * @param loggerName          name of the SLF4J logger used by this handler and the outbound pipeline
     */
    public DirectProxyUpstreamHandler(InetSocketAddress remoteSocketAddress, boolean secure, int bufferedCapacity, Interceptor interceptor, String loggerName) {
        this.remoteSocketAddress = remoteSocketAddress;
        this.secure = secure;
        this.bufferedCapacity = bufferedCapacity;
        this.interceptor = interceptor;
        this.logger = LoggerFactory.getLogger(loggerName);
        bufferedMode = bufferedCapacity > 0;
        flushedBuffer = false;
        contentLength = null;
        contentSoFar = 0;
        flushContent = false;
    }

    /**
     * Allocates the accumulation buffer when the handler joins a pipeline.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        // NOTE(review): Unpooled.directBuffer(int) only sets the INITIAL capacity, so the buffer can grow
        // past bufferedCapacity and the overflow path in bufferChunk is unlikely to trigger - confirm
        // whether a hard upper bound (directBuffer(initial, max)) was intended.
        this.channelBuffer = Unpooled.directBuffer(bufferedCapacity);
        super.handlerAdded(ctx);
    }

    /**
     * Releases the accumulation buffer unless it has already been released
     * (for example because it was written to the outbound channel).
     */
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        if (channelBuffer.refCnt() >= 1) {
            channelBuffer.release();
        }
        super.handlerRemoved(ctx);
    }

    /**
     * Opens the outbound connection on the inbound channel's event loop and, once it is established,
     * resets the buffering state and requests the first read from the inbound channel.
     * If the connection attempt fails the inbound channel is closed.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        final Channel inboundChannel = ctx.channel();

        // Start the connection attempt.
        Bootstrap bootstrap = new Bootstrap();
        bootstrap.group(inboundChannel.eventLoop())
                .channel(ctx.channel().getClass())
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    public void initChannel(SocketChannel ch) throws Exception {
                        // Create a default pipeline implementation.
                        ChannelPipeline pipeline = ch.pipeline();

                        // add logging
                        if (logger.isDebugEnabled()) {
                            pipeline.addLast("logger", new LoggingHandler(logger));
                        }

                        // add HTTPS proxy -> server support
                        if (secure) {
                            SSLEngine engine = SSLFactory.getInstance().sslContext().createSSLEngine();
                            engine.setUseClientMode(true);
                            pipeline.addLast("proxy -> server ssl", new SslHandler(engine));
                        }

                        // add handler that relays the server's response back to the inbound channel
                        pipeline.addLast(new ProxyRelayHandler(inboundChannel, bufferedCapacity, new ResponseInterceptor(), logger));
                    }
                })
                .option(ChannelOption.AUTO_READ, false);

        ChannelFuture channelFuture = bootstrap.connect(remoteSocketAddress);
        outboundChannel = channelFuture.channel();
        channelFuture.addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    channelBuffer.clear();
                    bufferedMode = bufferedCapacity > 0;
                    flushedBuffer = false;
                    // connection complete start to read first data
                    inboundChannel.read();
                } else {
                    // Close the connection if the connection attempt has failed.
                    logger.warn("Failed to connect to: " + remoteSocketAddress, future.cause());
                    inboundChannel.close();
                }
            }
        });
    }

    /**
     * When the inbound channel goes inactive, pushes any still-buffered bytes upstream (through the
     * interceptor) and then closes the outbound channel. Does nothing if the outbound channel is
     * not active.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        if (!isOutboundActive()) {
            return;
        }
        if (bufferedMode && channelBuffer.isReadable()) {
            flushedBuffer = true;
            // guard so the buffer is only decoded to a String when it will actually be logged
            if (logger.isDebugEnabled()) {
                logger.debug("CHANNEL INACTIVE: " + channelBuffer.toString(Charsets.UTF_8));
            }
            outboundChannel.writeAndFlush(interceptor.intercept(ctx, channelBuffer, logger)).addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (future.isSuccess()) {
                        channelBuffer.clear();
                        // flushed entire buffer upstream so close connection
                        closeOnFlush(outboundChannel);
                    } else {
                        logger.warn("Failed to send flush channel buffer", future.cause());
                        future.channel().close();
                    }
                }
            });
        } else {
            closeOnFlush(outboundChannel);
        }
    }

    /**
     * After a read batch completes, flushes whatever has been buffered upstream (through the
     * interceptor) and marks the buffer as flushed, which makes subsequent reads pass-through.
     */
    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        if (bufferedMode && isOutboundActive() && channelBuffer.isReadable()) {
            flushedBuffer = true;
            if (logger.isDebugEnabled()) {
                logger.debug("CHANNEL READ COMPLETE: " + channelBuffer.toString(Charsets.UTF_8));
            }
            outboundChannel.writeAndFlush(interceptor.intercept(ctx, channelBuffer, logger)).addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    if (future.isSuccess()) {
                        channelBuffer.clear();
                    } else {
                        logger.warn("Failed to write to: " + remoteSocketAddress, future.cause());
                        future.channel().close();
                    }
                }
            });
        }
        super.channelReadComplete(ctx);
    }

    /**
     * Handles one inbound message. {@link ByteBuf} chunks are either appended to the accumulation
     * buffer (buffered mode) or written straight to the outbound channel (pass-through mode);
     * any other message type is ignored.
     */
    @Override
    public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
        if (!(msg instanceof ByteBuf)) {
            return;
        }
        final ByteBuf chunk = (ByteBuf) msg;
        // once the buffer has been flushed upstream everything else is streamed
        if (flushedBuffer) {
            bufferedMode = false;
        }
        if (bufferedMode) {
            bufferChunk(ctx, chunk);
        } else {
            relayChunk(ctx, chunk);
        }
    }

    /**
     * Closes the inbound channel (after flushing) when an exception reaches this handler.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        logger.warn("Exception caught by http direct proxy handler closing pipeline", cause);
        Channel ch = ctx.channel();
        if (ch.isActive()) {
            closeOnFlush(ch);
        }
    }

    /**
     * Buffered mode: updates the content-length bookkeeping, copies the chunk into the
     * accumulation buffer and requests the next read.
     */
    private void bufferChunk(final ChannelHandlerContext ctx, final ByteBuf chunk) {
        final int chunkSize = chunk.readableBytes();
        flushContent = false;

        if (contentLength != null) {
            contentSoFar += chunkSize;
        } else {
            // find content length (decoder works on a copy so the chunk's indices are untouched)
            BasicHttpDecoder basicHttpDecoder = new BasicHttpDecoder(Unpooled.copiedBuffer(chunk));
            contentLength = basicHttpDecoder.getContentLength();
            contentSoFar = (chunkSize - basicHttpDecoder.getContentStart());
        }

        if (logger.isTraceEnabled()) {
            String lineSeparator = System.getProperty("line.separator");
            logger.trace("CHUNK: ---\n-" + lineSeparator + chunk.toString(Charsets.UTF_8) + "\n-" + lineSeparator);
            // byte count, not decoded character count, so multi-byte content is reported correctly
            logger.trace("CONTENT-SO-FAR-PRE-CHUNK: --- " + (contentSoFar - chunkSize));
            logger.trace("CHUNK-SIZE: --- " + chunkSize);
            logger.trace("CONTENT-SO-FAR-POST-CHUNK: --- " + contentSoFar);
            if (contentLength != null) {
                logger.trace("CONTENT-REMAINING: --- " + (contentLength - contentSoFar));
                logger.trace("CONTENT-LENGTH: --- " + contentLength);
            }
        }

        if (contentLength != null) {
            flushContent = (contentSoFar >= contentLength) || (chunkSize == 0);
            if (flushContent) {
                logger.trace("Flushing buffer as all content received");
            }
        }

        try {
            channelBuffer.writeBytes(chunk);
        } catch (IndexOutOfBoundsException iobe) {
            // chunk does not fit: its bytes were not consumed, so hand it over to the overflow path
            flushBufferThenRelay(ctx, chunk);
            return;
        }
        // bytes were copied into channelBuffer, so this handler is done with the inbound chunk
        chunk.release();
        ctx.channel().read();
    }

    /**
     * Overflow path: leaves buffered mode, writes the accumulated bytes upstream and then forwards
     * the chunk that did not fit. Releases the chunk if it cannot be forwarded.
     */
    private void flushBufferThenRelay(final ChannelHandlerContext ctx, final ByteBuf chunk) {
        logger.trace("Flushing buffer upstream and switching to chunked mode as downstream response too large");
        bufferedMode = false;

        if (!isOutboundActive()) {
            // nowhere to send the data
            chunk.release();
            return;
        }
        if (!channelBuffer.isReadable()) {
            // nothing buffered yet - this single chunk is too large for the buffer, so stream it directly
            relayChunk(ctx, chunk);
            return;
        }

        if (logger.isDebugEnabled()) {
            logger.debug("CHANNEL READ EX: " + chunk.toString(Charsets.UTF_8));
        }
        // write and flush buffer upstream
        outboundChannel.writeAndFlush(channelBuffer).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    // write and flush this chunk upstream in case this single chunk is too large for buffer
                    relayChunk(ctx, chunk);
                } else {
                    logger.warn("Failed to write to: " + remoteSocketAddress, future.cause());
                    // the chunk was never written anywhere, so it must be released here
                    chunk.release();
                    future.channel().close();
                }
            }
        });
    }

    /**
     * Pass-through mode: writes the chunk to the outbound channel and requests the next read once
     * the write succeeds. Releases the chunk if the outbound channel is not active.
     */
    private void relayChunk(final ChannelHandlerContext ctx, final ByteBuf chunk) {
        if (!isOutboundActive()) {
            chunk.release();
            return;
        }
        if (logger.isDebugEnabled()) {
            logger.debug("CHANNEL READ NOT-BUFFERING: " + chunk.toString(Charsets.UTF_8));
        }
        outboundChannel.writeAndFlush(chunk).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (future.isSuccess()) {
                    // was able to flush out data, start to read the next chunk
                    ctx.channel().read();
                } else {
                    logger.warn("Failed to write to: " + remoteSocketAddress, future.cause());
                    future.channel().close();
                }
            }
        });
    }

    /**
     * Returns {@code true} when the outbound channel has been created and is active.
     */
    private boolean isOutboundActive() {
        Channel outbound = outboundChannel;
        return outbound != null && outbound.isActive();
    }

    /**
     * Writes an empty buffer to the channel and closes it once that write completes.
     */
    private static void closeOnFlush(Channel channel) {
        channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
    }
}