package com.proofpoint.http.client.jetty;
import com.google.common.collect.ImmutableList;
import com.google.common.io.ByteArrayDataOutput;
import com.google.common.io.ByteStreams;
import com.google.common.net.HostAndPort;
import com.google.common.net.InetAddresses;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.util.concurrent.MoreExecutors.listeningDecorator;
import static com.proofpoint.concurrent.Threads.threadsNamed;
import static java.util.concurrent.Executors.newCachedThreadPool;
/**
 * A minimal SOCKS proxy for use in tests.
 * <p>
 * Handles SOCKS 4 and SOCKS 4a CONNECT requests and relays bytes between the
 * client and the requested target. SOCKS 5 is not implemented: such a request
 * makes the worker throw {@link UnsupportedOperationException}. Connections
 * announcing any other version are closed.
 */
public class TestingSocksProxy
        implements Closeable
{
    private static final int SOCKS_4_SUCCESS = 0x5a;
    private static final int SOCKS_4_FAILED = 0x5b;

    private final int bindPort;
    private HostAndPort hostAndPort;
    private ListeningExecutorService executorService;
    private ServerSocket serverSocket;

    /**
     * Creates a proxy that will bind to port 0 (an ephemeral port) when started.
     */
    public TestingSocksProxy()
    {
        this(0);
    }

    /**
     * Creates a proxy that will bind to {@code bindPort} when started.
     *
     * @param bindPort the local port to listen on; 0 selects an ephemeral port
     */
    public TestingSocksProxy(int bindPort)
    {
        this.bindPort = bindPort;
    }

    /**
     * Returns the address the proxy is listening on.
     *
     * @throws IllegalStateException if the proxy is not running
     */
    public synchronized HostAndPort getHostAndPort()
    {
        checkState(hostAndPort != null, "%s is not running", getClass().getName());
        return hostAndPort;
    }

    /**
     * Binds the server socket and starts accepting client connections.
     *
     * @return this proxy
     * @throws IllegalStateException if the proxy was already started
     * @throws IOException if the server socket cannot be opened; the proxy is
     * closed again before the exception propagates
     */
    public synchronized TestingSocksProxy start()
            throws IOException
    {
        checkState(serverSocket == null, "%s already started", getClass().getName());
        try {
            serverSocket = new ServerSocket(bindPort);
            hostAndPort = HostAndPort.fromParts(serverSocket.getInetAddress().getHostAddress(), serverSocket.getLocalPort());
            executorService = listeningDecorator(newCachedThreadPool(threadsNamed("socks-proxy-" + serverSocket.getLocalPort() + "-%s")));
            executorService.execute(new SocksProxyAcceptor(serverSocket, executorService));
            return this;
        }
        catch (Throwable e) {
            close();
            throw e;
        }
    }

    /**
     * Stops accepting connections and shuts down all proxy threads.
     * Calling this more than once, or before {@link #start()}, is harmless.
     */
    @Override
    public synchronized void close()
    {
        hostAndPort = null;
        if (serverSocket != null) {
            closeIgnoreException(serverSocket);
            serverSocket = null;
        }
        if (executorService != null) {
            executorService.shutdownNow();
            executorService = null;
        }
    }

    /**
     * Accepts client connections and hands each one to a {@link SocksProxyWorker}.
     */
    private static class SocksProxyAcceptor
            implements Runnable
    {
        private final ServerSocket serverSocket;
        private final ListeningExecutorService executorService;
        private final AtomicBoolean closed = new AtomicBoolean();

        private SocksProxyAcceptor(ServerSocket serverSocket, ListeningExecutorService executorService)
        {
            this.serverSocket = serverSocket;
            this.executorService = executorService;
        }

        @Override
        public void run()
        {
            try {
                while (!closed.get() && !serverSocket.isClosed() && !Thread.currentThread().isInterrupted()) {
                    Socket socket;
                    try {
                        socket = serverSocket.accept();
                    }
                    catch (IOException ignored) {
                        // accept() fails once the server socket is closed; the loop condition then ends the loop
                        continue;
                    }
                    try {
                        executorService.execute(new SocksProxyWorker(socket, executorService));
                    }
                    catch (RuntimeException e) {
                        // e.g. the executor was shut down by close(): do not leak the accepted connection
                        closeIgnoreException(socket);
                        throw e;
                    }
                }
            }
            finally {
                closeIgnoreException(serverSocket);
            }
        }
    }

    /**
     * Handles a single client connection: parses the SOCKS request, connects to
     * the target and relays data in both directions.
     */
    private static class SocksProxyWorker
            implements Runnable
    {
        private final Socket socket;
        private final ListeningExecutorService executor;

        private SocksProxyWorker(Socket socket, ListeningExecutorService executor)
        {
            this.socket = socket;
            this.executor = executor;
        }

        @Override
        public void run()
        {
            boolean handedOff = false;
            try {
                handedOff = connect();
            }
            catch (IOException ignored) {
                // nothing we can do about this; the client socket is closed below
            }
            finally {
                // once handed off, the relay pipes own the socket and close it when done
                if (!handedOff) {
                    closeIgnoreException(socket);
                }
            }
        }

        /**
         * Reads the SOCKS version byte and dispatches to the matching handler.
         *
         * @return {@code true} if the client socket was handed to the relay
         * pipes (which close it); {@code false} if the caller must close it
         */
        private boolean connect()
                throws IOException
        {
            DataInputStream sourceInput = new DataInputStream(socket.getInputStream());
            DataOutputStream sourceOutput = new DataOutputStream(socket.getOutputStream());

            // field 1: SOCKS version number, 1 byte
            int version = sourceInput.read();
            if (version == 4) {
                return socks4(sourceInput, sourceOutput);
            }
            if (version == 5) {
                socks5(sourceInput, sourceOutput);
            }
            // unsupported version (or end of stream): the caller closes the socket
            return false;
        }

        /**
         * Handles a SOCKS 4 / 4a request. Only the CONNECT command is supported.
         *
         * @return {@code true} if the connection was established and handed to the
         * relay pipes; {@code false} if a failure reply was sent and the caller
         * must close the client socket
         * @throws EOFException if the client closes the connection mid-request
         */
        private boolean socks4(DataInputStream sourceInput, DataOutputStream sourceOutput)
                throws IOException
        {
            // field 2: command code, 1 byte: 0x01 = connect, 0x02 = bind
            int command = sourceInput.read();

            // field 3: network byte order port number, 2 bytes
            int port = sourceInput.readUnsignedShort();

            // field 4: network byte order IP address, 4 bytes
            int address = sourceInput.readInt();

            // field 5: the user ID string, variable length, terminated with a null (0x00); ignored
            readNullTerminated(sourceInput);

            if (command != 1) {
                // we only support connect requests
                responseSocks4(sourceOutput, SOCKS_4_FAILED, 0, 0);
                return false;
            }

            // Socks 4a: if address is 0x0000_00xx where xx is not 0, we have a domain name
            String domainName = null;
            if (address != 0 && (address & 0xFFFF_FF00) == 0) {
                // field 6: the domain name of the host we want to contact, variable length, terminated with a null (0x00)
                domainName = readNullTerminated(sourceInput);
            }

            Socket targetSocket;
            try {
                if (domainName != null) {
                    targetSocket = new Socket(domainName, port);
                }
                else {
                    targetSocket = new Socket(InetAddresses.fromInteger(address), port);
                }
            }
            catch (IOException e) {
                // could not resolve name or open socket
                responseSocks4(sourceOutput, SOCKS_4_FAILED, 0, 0);
                return false;
            }

            try {
                InputStream targetInput = targetSocket.getInputStream();
                OutputStream targetOutput = targetSocket.getOutputStream();

                // send success message
                responseSocks4(sourceOutput, SOCKS_4_SUCCESS, port, InetAddresses.coerceToInteger(targetSocket.getInetAddress()));

                proxyData(sourceInput, sourceOutput, targetInput, targetOutput);
                return true;
            }
            catch (Throwable e) {
                // the relay never started (or only partially): do not leak the target connection
                closeIgnoreException(targetSocket);
                throw e;
            }
        }

        /**
         * Reads bytes up to (and consuming) a terminating null byte and returns
         * them as a string, one char per byte.
         *
         * @throws EOFException if the stream ends before the null terminator
         */
        private static String readNullTerminated(InputStream input)
                throws IOException
        {
            StringBuilder builder = new StringBuilder(64);
            for (int value = input.read(); value != 0; value = input.read()) {
                if (value < 0) {
                    // without this check a disconnected client would make this loop spin forever
                    throw new EOFException("connection closed before null-terminated field was complete");
                }
                builder.append((char) value);
            }
            return builder.toString();
        }

        /**
         * Writes an 8-byte SOCKS 4 reply: null byte, status, port, IPv4 address.
         */
        private void responseSocks4(DataOutputStream output, int status, int port, int address)
                throws IOException
        {
            ByteArrayDataOutput sourceOutput = ByteStreams.newDataOutput();

            // field 1: null byte
            sourceOutput.write(0);

            // field 2: status, 1 byte:
            sourceOutput.write(status);

            // field 3: network byte order port number, 2 bytes
            sourceOutput.writeShort(port);

            // field 4: network byte order IP address, 4 bytes
            sourceOutput.writeInt(address);

            // write all at once to avoid Jetty bug
            // TODO: remove this when fixed in Jetty
            output.write(sourceOutput.toByteArray());
        }

        /**
         * SOCKS 5 is not implemented; always throws.
         *
         * @throws UnsupportedOperationException always
         */
        private void socks5(DataInputStream sourceInput, DataOutputStream sourceOutput)
        {
            // adding socks5 no_auth support would be trivial, but we need a client to test it
            throw new UnsupportedOperationException();
        }

        /**
         * Starts two pipes (client to target, target to client) on the executor and
         * closes the client socket once both have finished.
         */
        private void proxyData(InputStream sourceInput, OutputStream sourceOutput, InputStream targetInput, OutputStream targetOutput)
        {
            // pipe in to out and out to in
            List<ListenableFuture<?>> jobs = ImmutableList.of(
                    executor.submit(new Pipe(sourceInput, targetOutput)),
                    executor.submit(new Pipe(targetInput, sourceOutput)));

            // close socket when both jobs finish
            Futures.addCallback(Futures.allAsList(jobs), new FutureCallback<List<Object>>()
            {
                @Override
                public void onSuccess(List<Object> result)
                {
                    closeIgnoreException(socket);
                }

                @Override
                public void onFailure(Throwable ignored)
                {
                    closeIgnoreException(socket);
                }
            });
        }
    }

    /**
     * Copies bytes from {@code in} to {@code out} until end of stream or an I/O
     * error, then closes both streams (closing a socket stream closes its socket).
     */
    private static class Pipe
            implements Runnable
    {
        private final InputStream in;
        private final OutputStream out;

        private Pipe(InputStream in, OutputStream out)
        {
            this.in = in;
            this.out = out;
        }

        @Override
        public void run()
        {
            try {
                ByteStreams.copy(in, out);
            }
            catch (IOException e) {
                // ignored nothing we can do about this
            }
            finally {
                closeIgnoreException(in);
                closeIgnoreException(out);
            }
        }
    }

    /**
     * Closes {@code closeable}, ignoring {@code null} and any {@link IOException}.
     */
    private static void closeIgnoreException(Closeable closeable)
    {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        }
        catch (IOException ignored) {
            // nothing we can do about this
        }
    }
}