package com.thimbleware.jmemcached.protocol.binary;
import com.thimbleware.jmemcached.protocol.Command;
import com.thimbleware.jmemcached.protocol.ResponseMessage;
import com.thimbleware.jmemcached.protocol.exceptions.UnknownCommandException;
import com.thimbleware.jmemcached.CacheElement;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.ByteOrder;
import java.util.Set;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Upstream Netty handler that encodes {@link ResponseMessage}s as memcached binary-protocol
 * responses and writes them to the channel.
 *
 * <p>Each response is a fixed 24-byte header followed by optional extras, key and value
 * sections. Responses to "quiet" commands ({@code bcmd.noreply}) are not written immediately;
 * they are accumulated ("corked") in a buffer keyed by the command's opaque value and flushed
 * just before the next non-quiet response carrying the same opaque value.
 *
 * <p>NOTE(review): the handler is marked {@code Sharable} while the corked buffers are keyed
 * only by opaque value, so two channels using the same opaque value would appear to share one
 * corked buffer -- confirm whether an instance is really shared between channels.
 */
// TODO refactor so this can be unit tested separate from netty? scalacheck?
@ChannelHandler.Sharable
public class MemcachedBinaryResponseEncoder<CACHE_ELEMENT extends CacheElement> extends SimpleChannelUpstreamHandler {

    /** Size in bytes of the fixed response header. */
    private static final int HEADER_LENGTH = 24;

    /** Magic byte written at the start of every response header. */
    private static final int RESPONSE_MAGIC = 0x81;

    /** Pending responses of quiet commands, keyed by the opaque value of the command. */
    private final ConcurrentHashMap<Integer, ChannelBuffer> corkedBuffers = new ConcurrentHashMap<Integer, ChannelBuffer>();

    final Logger logger = LoggerFactory.getLogger(MemcachedBinaryResponseEncoder.class);

    /** Status codes written into the status field of the response header. */
    public static enum ResponseCode {
        OK(0x0000),
        KEYNF(0x0001),
        KEYEXISTS(0x0002),
        TOOLARGE(0x0003),
        INVARG(0x0004),
        NOT_STORED(0x0005),
        UNKNOWN(0x0081),
        OOM(0x0082);

        public short code;

        ResponseCode(int code) {
            this.code = (short) code;
        }
    }

    /**
     * Maps the outcome stored in a response message to a binary-protocol status code.
     *
     * <p>NOTE(review): GET/GETS always map to {@link ResponseCode#OK}, even when no element was
     * found -- confirm this is intended rather than {@link ResponseCode#KEYNF} on a miss.
     *
     * @param command the response message to inspect
     * @return the matching status code, or {@link ResponseCode#UNKNOWN} when the command or its
     *         outcome (including an unset outcome field) is not recognised
     */
    public ResponseCode getStatusCode(ResponseMessage command) {
        Command cmd = command.cmd.cmd;
        if (cmd == Command.GET || cmd == Command.GETS) {
            return ResponseCode.OK;
        } else if (cmd == Command.SET || cmd == Command.CAS || cmd == Command.ADD || cmd == Command.REPLACE || cmd == Command.APPEND || cmd == Command.PREPEND) {
            // switching on a null enum reference would throw; an unset outcome is reported as UNKNOWN
            if (command.response != null) {
                switch (command.response) {
                    case EXISTS:
                        return ResponseCode.KEYEXISTS;
                    case NOT_FOUND:
                        return ResponseCode.KEYNF;
                    case NOT_STORED:
                        return ResponseCode.NOT_STORED;
                    case STORED:
                        return ResponseCode.OK;
                }
            }
        } else if (cmd == Command.INCR || cmd == Command.DECR) {
            return command.incrDecrResponse == null ? ResponseCode.KEYNF : ResponseCode.OK;
        } else if (cmd == Command.DELETE) {
            if (command.deleteResponse != null) {
                switch (command.deleteResponse) {
                    case DELETED:
                        return ResponseCode.OK;
                    case NOT_FOUND:
                        return ResponseCode.KEYNF;
                }
            }
        } else if (cmd == Command.STATS || cmd == Command.VERSION || cmd == Command.FLUSH_ALL) {
            return ResponseCode.OK;
        }
        return ResponseCode.UNKNOWN;
    }

    /**
     * Builds the 24-byte big-endian response header.
     *
     * <p>Section lengths are taken from the readable bytes of each buffer, because that is what
     * a subsequent channel write transmits.
     *
     * @param bcmd         binary command whose opcode is echoed in the header
     * @param extrasBuffer extras section, or {@code null} if there is none
     * @param keyBuffer    key section, or {@code null} if there is none
     * @param valueBuffer  value section, or {@code null} if there is none
     * @param responseCode status code to report
     * @param opaqueValue  opaque value to echo back
     * @param casUnique    CAS value to report
     * @return a new buffer holding the header
     */
    public ChannelBuffer constructHeader(MemcachedBinaryCommandDecoder.BinaryCommand bcmd, ChannelBuffer extrasBuffer, ChannelBuffer keyBuffer, ChannelBuffer valueBuffer, short responseCode, int opaqueValue, long casUnique) {
        int keyLength = keyBuffer != null ? keyBuffer.readableBytes() : 0;
        int extrasLength = extrasBuffer != null ? extrasBuffer.readableBytes() : 0;
        int dataLength = valueBuffer != null ? valueBuffer.readableBytes() : 0;

        ChannelBuffer header = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, HEADER_LENGTH);
        header.writeByte(RESPONSE_MAGIC); // magic
        header.writeByte(bcmd.code); // opcode
        header.writeShort(keyLength); // key length (low 16 bits)
        header.writeByte((byte) extrasLength); // extras length
        header.writeByte((byte) 0); // data type unused
        header.writeShort(responseCode); // status code
        header.writeInt(dataLength + keyLength + extrasLength); // total body length
        header.writeInt(opaqueValue); // opaque
        header.writeLong(casUnique); // cas
        return header;
    }

    /**
     * Handle exceptions in protocol processing. An unknown command is reported to the client
     * with an {@link ResponseCode#UNKNOWN} header; any other failure is logged and the channel
     * is closed.
     *
     * @param ctx handler context giving access to the channel
     * @param e   event carrying the failure cause
     * @throws Exception if writing to or closing the channel fails
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
        Throwable cause = e.getCause();
        Channel channel = ctx.getChannel();
        if (cause instanceof UnknownCommandException) {
            if (channel.isOpen()) {
                channel.write(constructHeader(MemcachedBinaryCommandDecoder.BinaryCommand.Noop, null, null, null, ResponseCode.UNKNOWN.code, 0, 0));
            }
        } else {
            logger.error("error", cause);
            if (channel.isOpen()) {
                channel.close();
            }
        }
    }

    /**
     * Encodes one {@link ResponseMessage} and writes it, or corks it if the command is quiet.
     *
     * <p>STATS responses are written as one packet per stat value followed by a terminating
     * packet with no key and no value.
     *
     * @param channelHandlerContext handler context (unused)
     * @param messageEvent          event whose message is the {@link ResponseMessage} to encode
     * @throws Exception if encoding or writing fails
     */
    @Override
    @SuppressWarnings("unchecked")
    public void messageReceived(ChannelHandlerContext channelHandlerContext, MessageEvent messageEvent) throws Exception {
        ResponseMessage<CACHE_ELEMENT> command = (ResponseMessage<CACHE_ELEMENT>) messageEvent.getMessage();
        MemcachedBinaryCommandDecoder.BinaryCommand bcmd = MemcachedBinaryCommandDecoder.BinaryCommand.forCommandMessage(command.cmd);
        Command cmd = command.cmd.cmd;
        int opaque = command.cmd.opaque;

        // key section, only for commands that echo the key back
        ChannelBuffer keyBuffer = null;
        if (bcmd.addKeyToResponse && command.cmd.keys != null && command.cmd.keys.size() != 0) {
            // same charset as the stats keys/values below, rather than the platform default
            keyBuffer = ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, command.cmd.keys.get(0).getBytes(MemcachedBinaryCommandDecoder.USASCII));
        }

        // extras (expiry + flags), value and cas come from the first element, if any
        ChannelBuffer extrasBuffer = null;
        ChannelBuffer valueBuffer = null;
        long casUnique = 0;
        if (command.elements != null) {
            // an empty element array is treated like a missing element instead of indexing past its end
            CacheElement element = command.elements.length != 0 ? command.elements[0] : null;

            // NOTE(review): expiry and flags are each truncated to 16 bits here; confirm this
            // layout against what binary-protocol clients expect in the GET extras.
            extrasBuffer = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, 4);
            extrasBuffer.writeShort((short) (element != null ? element.getExpire() : 0));
            extrasBuffer.writeShort((short) (element != null ? element.getFlags() : 0));

            if (cmd == Command.GET || cmd == Command.GETS) {
                valueBuffer = element != null
                        ? ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, element.getData())
                        : ChannelBuffers.buffer(0);
            }
            if (element != null) {
                casUnique = element.getCasUnique();
            }
        }

        // incr/decr carry the new counter value; a null result means the key was missing
        // (reported as KEYNF by getStatusCode), so there is no value to write
        if ((cmd == Command.INCR || cmd == Command.DECR) && command.incrDecrResponse != null) {
            valueBuffer = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, 8);
            valueBuffer.writeLong(command.incrDecrResponse);
        }

        short status = getStatusCode(command).code;

        if (cmd == Command.STATS) {
            writeStats(messageEvent, command, bcmd, extrasBuffer, status, opaque, casUnique);
            return;
        }

        ChannelBuffer headerBuffer = constructHeader(bcmd, extrasBuffer, keyBuffer, valueBuffer, status, opaque, casUnique);
        if (bcmd.noreply) {
            // quiet command: append to the corked buffer until a non-quiet command comes along
            int totalLength = headerBuffer.readableBytes()
                    + (extrasBuffer != null ? extrasBuffer.readableBytes() : 0)
                    + (keyBuffer != null ? keyBuffer.readableBytes() : 0)
                    + (valueBuffer != null ? valueBuffer.readableBytes() : 0);
            ChannelBuffer corkedResponse = cork(opaque, totalLength);
            corkedResponse.writeBytes(headerBuffer);
            if (extrasBuffer != null) {
                corkedResponse.writeBytes(extrasBuffer);
            }
            if (keyBuffer != null) {
                corkedResponse.writeBytes(keyBuffer);
            }
            if (valueBuffer != null) {
                corkedResponse.writeBytes(valueBuffer);
            }
        } else {
            // first write out any corked responses
            uncork(opaque, messageEvent.getChannel());
            writePayload(messageEvent, extrasBuffer, keyBuffer, valueBuffer, headerBuffer);
        }
    }

    /**
     * Writes a STATS response: any corked responses first, then one packet per stat value,
     * then an empty packet (no key, no value) marking the end of the stats.
     */
    private void writeStats(MessageEvent messageEvent, ResponseMessage<CACHE_ELEMENT> command, MemcachedBinaryCommandDecoder.BinaryCommand bcmd, ChannelBuffer extrasBuffer, short status, int opaque, long casUnique) throws Exception {
        uncork(opaque, messageEvent.getChannel());

        if (command.stats != null) {
            for (Map.Entry<String, Set<String>> statsEntries : command.stats.entrySet()) {
                for (String stat : statsEntries.getValue()) {
                    ChannelBuffer statKey = ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, statsEntries.getKey().getBytes(MemcachedBinaryCommandDecoder.USASCII));
                    ChannelBuffer statValue = ChannelBuffers.wrappedBuffer(ByteOrder.BIG_ENDIAN, stat.getBytes(MemcachedBinaryCommandDecoder.USASCII));
                    ChannelBuffer statHeader = constructHeader(bcmd, extrasBuffer, statKey, statValue, status, opaque, casUnique);
                    writePayload(messageEvent, extrasBuffer, statKey, statValue, statHeader);
                }
            }
        }

        ChannelBuffer terminator = constructHeader(bcmd, extrasBuffer, null, null, status, opaque, casUnique);
        writePayload(messageEvent, extrasBuffer, null, null, terminator);
    }

    /**
     * Returns a corked buffer for {@code opaque} with room for {@code totalCapacity} more bytes.
     * Bytes already corked under the same opaque value are copied into the new buffer, which
     * replaces the old one in the map.
     */
    private ChannelBuffer cork(int opaque, int totalCapacity) {
        ChannelBuffer previous = corkedBuffers.get(opaque);
        int previousLength = previous != null ? previous.readableBytes() : 0;

        ChannelBuffer grown = ChannelBuffers.buffer(ByteOrder.BIG_ENDIAN, totalCapacity + previousLength);
        if (previous != null) {
            grown.writeBytes(previous);
        }
        corkedBuffers.put(opaque, grown);
        return grown;
    }

    /**
     * Removes the corked buffer for {@code opaque}, if there is one, and writes it to the channel.
     */
    private void uncork(int opaque, Channel channel) {
        ChannelBuffer corkedBuffer = corkedBuffers.remove(opaque);
        if (corkedBuffer != null) {
            channel.write(corkedBuffer);
        }
    }

    /**
     * Writes header, extras, key and value (skipping {@code null} sections) to the event's
     * channel, in that order. Nothing is written if the channel is no longer open.
     */
    private void writePayload(MessageEvent messageEvent, ChannelBuffer extrasBuffer, ChannelBuffer keyBuffer, ChannelBuffer valueBuffer, ChannelBuffer headerBuffer) {
        Channel channel = messageEvent.getChannel();
        if (!channel.isOpen()) {
            return;
        }
        channel.write(headerBuffer);
        if (extrasBuffer != null) {
            channel.write(extrasBuffer);
        }
        if (keyBuffer != null) {
            channel.write(keyBuffer);
        }
        if (valueBuffer != null) {
            channel.write(valueBuffer);
        }
    }
}