package redis.netty.client;
import com.google.common.base.Charsets;
import com.google.common.primitives.UnsignedBytes;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;
import redis.Command;
import redis.netty.BulkReply;
import redis.netty.ErrorReply;
import redis.netty.MultiBulkReply;
import redis.netty.RedisDecoder;
import redis.netty.RedisEncoder;
import redis.netty.Reply;
import spullara.util.concurrent.Promise;
import spullara.util.functions.Block;
import java.io.BufferedReader;
import java.io.StringReader;
import java.net.InetSocketAddress;
import java.util.Comparator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
* Used by the generated code to execute commands.
*/
public class RedisClientBase {
private static final byte[] MESSAGE = "message".getBytes(Charsets.US_ASCII);
private static final byte[] PMESSAGE = "pmessage".getBytes(Charsets.US_ASCII);
private static final byte[] SUBSCRIBE = "subscribe".getBytes(Charsets.US_ASCII);
private static final byte[] PSUBSCRIBE = "psubscribe".getBytes(Charsets.US_ASCII);
private static final byte[] UNSUBSCRIBE = "unsubscribe".getBytes(Charsets.US_ASCII);
private static final byte[] PUNSUBSCRIBE = "punsubscribe".getBytes(Charsets.US_ASCII);
private Channel channel;
private Queue<Promise> queue;
private ExecutorService executor;
public static <T extends RedisClientBase> Promise<T> connect(String hostname, int port) {
return connect(hostname, port, (T) new RedisClientBase(), Executors.newCachedThreadPool());
}
private static final Pattern versionMatcher = Pattern.compile(
"([0-9]+)\\.([0-9]+)(\\.([0-9]+))?");
protected int version = 9999999;
protected void parseInfo(BulkReply info) {
try {
BufferedReader br = new BufferedReader(new StringReader(info.asUTF8String()));
String line;
while ((line = br.readLine()) != null) {
int index = line.indexOf(':');
if (index != -1) {
String name = line.substring(0, index);
String value = line.substring(index + 1);
if ("redis_version".equals(name)) {
this.version = parseVersion(value);
}
}
}
} catch (Exception re) {
// Server requires AUTH, check later
}
}
protected static int parseVersion(String value) {
int version = 0;
Matcher matcher = versionMatcher.matcher(value);
if (matcher.matches()) {
String major = matcher.group(1);
String minor = matcher.group(2);
String patch = matcher.group(4);
version = 100 * Integer.parseInt(minor) + 10000 * Integer.parseInt(major);
if (patch != null) {
version += Integer.parseInt(patch);
}
}
return version;
}
public static <T extends RedisClientBase> Promise<T> connect(String hostname, int port, final T redisClient, final ExecutorService executor) {
final ClientBootstrap cb = new ClientBootstrap(new NioClientSocketChannelFactory(executor, executor));
final Queue<Promise> queue = new LinkedTransferQueue<>();
final SimpleChannelUpstreamHandler handler = new SimpleChannelUpstreamHandler() {
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
Object message = e.getMessage();
if (queue.isEmpty()) {
if (message instanceof MultiBulkReply) {
redisClient.handleMessage(message);
} else {
// Need some way to notify
}
} else {
Promise poll = queue.poll();
if (message instanceof ErrorReply) {
poll.setException(new RedisException(((ErrorReply) message).data()));
} else {
poll.set(message);
}
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
if (queue.isEmpty()) {
// Needed for pub/sub?
} else {
Promise poll = queue.poll();
poll.setException(e.getCause());
}
}
};
final RedisEncoder encoder = new RedisEncoder();
final RedisDecoder decoder = new RedisDecoder();
cb.setPipelineFactory(new ChannelPipelineFactory() {
@Override
public ChannelPipeline getPipeline() throws Exception {
ChannelPipeline pipeline = Channels.pipeline();
pipeline.addLast("redisEncoder", encoder);
pipeline.addLast("redisDecoder", decoder);
pipeline.addLast("result", handler);
return pipeline;
}
});
final Promise<T> redisClientBasePromise = new Promise<>();
cb.connect(new InetSocketAddress(hostname, port)).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture channelFuture) throws Exception {
if (channelFuture.isSuccess()) {
redisClient.init(channelFuture.getChannel(), queue, executor);
redisClient.execute(BulkReply.class, new Command("INFO")).onSuccess(new Block<BulkReply>() {
@Override
public void apply(BulkReply bulkReply) {
redisClient.parseInfo(bulkReply);
redisClientBasePromise.set(redisClient);
}
}).onFailure(new Block<Throwable>() {
@Override
public void apply(Throwable throwable) {
redisClientBasePromise.setException(throwable);
}
});
} else if (channelFuture.isCancelled()) {
redisClientBasePromise.cancel(true);
} else {
redisClientBasePromise.setException(channelFuture.getCause());
}
}
});
return redisClientBasePromise;
}
protected RedisClientBase() {
}
protected void init(Channel channel, Queue<Promise> queue, ExecutorService executor) {
this.channel = channel;
this.queue = queue;
this.executor = executor;
}
public Promise<Void> close() {
final Promise<Void> closed = new Promise<>();
channel.close().addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture channelFuture) throws Exception {
if (channelFuture.isSuccess()) {
closed.set(null);
} else if (channelFuture.isCancelled()) {
closed.cancel(true);
} else {
closed.setException(channelFuture.getCause());
}
}
});
return closed;
}
private final Semaphore writerLock = new Semaphore(1);
protected <T> Promise<T> execute(final Class<T> clazz, final Command command) {
final Promise<T> reply = new Promise<T>() {
@Override
public void set(T value) {
// Check the type and fail if the wrong type
if (!clazz.isInstance(value)) {
setException(new RedisException("Incorrect type for " + value + " should be " + clazz.getName() + " but is " + value.getClass().getName()));
} else {
super.set(value);
}
}
};
if (subscribed.get()) {
reply.setException(new RedisException("Already subscribed, cannot send this command"));
} else {
executor.submit(new Runnable() {
@Override
public void run() {
ChannelFuture write;
synchronized (channel) {
queue.add(reply);
write = channel.write(command);
write.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
// Netty doesn't call these in order
} else if (future.isCancelled()) {
reply.cancel(true);
} else {
reply.setException(future.getCause());
}
}
});
}
}
});
}
return reply;
}
// Publish/subscribe section
private AtomicBoolean subscribed = new AtomicBoolean();
private void subscribed() {
subscribed.set(true);
}
/**
* Subscribes the client to the specified channels.
*
* @param subscriptions
*/
public Promise<Void> subscribe(Object... subscriptions) {
subscribed();
Promise<Void> result = new Promise<>();
channel.write(new Command(SUBSCRIBE, subscriptions)).addListener(wrapSubscribe(result));
return result;
}
private ChannelFutureListener wrapSubscribe(final Promise<Void> result) {
return new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
result.set(null);
} else if (future.isCancelled()) {
result.cancel(true);
} else {
result.setException(future.getCause());
}
}
};
}
/**
* Subscribes the client to the specified patterns.
*
* @param subscriptions
*/
public Promise<Void> psubscribe(Object... subscriptions) {
subscribed();
Promise<Void> result = new Promise<>();
channel.write(new Command(PSUBSCRIBE, subscriptions)).addListener(wrapSubscribe(result));
return result;
}
/**
* Unsubscribes the client to the specified channels.
*
* @param subscriptions
*/
public Promise<Void> unsubscribe(Object... subscriptions) {
subscribed();
Promise<Void> result = new Promise<>();
channel.write(new Command(UNSUBSCRIBE, subscriptions)).addListener(wrapSubscribe(result));
return result;
}
/**
* Unsubscribes the client to the specified patterns.
*
* @param subscriptions
*/
public Promise<Void> punsubscribe(Object... subscriptions) {
subscribed();
Promise<Void> result = new Promise<>();
channel.write(new Command(PUNSUBSCRIBE, subscriptions)).addListener(wrapSubscribe(result));
return result;
}
private List<ReplyListener> replyListeners = new CopyOnWriteArrayList<>();
/**
* Add a reply listener to this client for subscriptions.
*/
public void addListener(ReplyListener replyListener) {
replyListeners.add(replyListener);
}
/**
* Remove a reply listener from this client.
*/
public boolean removeListener(ReplyListener replyListener) {
return replyListeners.remove(replyListener);
}
private static Comparator<byte[]> BYTES = UnsignedBytes.lexicographicalComparator();
protected void handleMessage(Object message) {
MultiBulkReply reply = (MultiBulkReply) message;
Reply[] data = reply.data();
if (data.length != 3 && data.length != 4) {
throw new RedisException("Invalid subscription messsage");
}
for (ReplyListener replyListener : replyListeners) {
byte[] type = getBytes(data[0].data());
byte[] data1 = getBytes(data[1].data());
Object data2 = data[2].data();
switch (type.length) {
case 7:
if (BYTES.compare(type, MESSAGE) == 0) {
replyListener.message(data1, getBytes(data2));
continue;
}
break;
case 8:
if (BYTES.compare(type, PMESSAGE) == 0) {
replyListener.pmessage(data1, (byte[]) data2, ((ChannelBuffer) data[3].data()).array());
continue;
}
break;
case 9:
if (BYTES.compare(type, SUBSCRIBE) == 0) {
replyListener.subscribed(data1, ((Number) data2).intValue());
continue;
}
break;
case 10:
if (BYTES.compare(type, PSUBSCRIBE) == 0) {
replyListener.psubscribed(data1, ((Number) data2).intValue());
continue;
}
break;
case 11:
if (BYTES.compare(type, UNSUBSCRIBE) == 0) {
replyListener.unsubscribed(data1, ((Number) data2).intValue());
continue;
}
break;
case 12:
if (BYTES.compare(type, PUNSUBSCRIBE) == 0) {
replyListener.punsubscribed(data1, ((Number) data2).intValue());
continue;
}
break;
default:
break;
}
close();
}
}
private byte[] getBytes(Object data2) {
ChannelBuffer d = (ChannelBuffer) data2;
byte[] bytes = new byte[d.readableBytes()];
d.getBytes(d.readerIndex(), bytes);
return bytes;
}
}