package play.server.hybi10;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.frame.CorruptedFrameException;
import org.jboss.netty.handler.codec.http.websocket.DefaultWebSocketFrame;
import org.jboss.netty.handler.codec.replay.ReplayingDecoder;
import java.util.ArrayList;
import java.util.List;
/**
 * Decodes masked WebSocket frames in the hybi-10 wire format into message objects.
 *
 * <p>The decoder is a small state machine driven by {@link ReplayingDecoder}: each state
 * reads one part of the frame header (first byte, length byte, optional extended length,
 * masking key) and finally the payload. A checkpoint is set after every part so that a
 * replay caused by missing bytes resumes at the part that was incomplete.
 *
 * <p>Results produced by {@link #decode}:
 * <ul>
 *   <li>text message: {@code DefaultWebSocketFrame} of type {@code 0x00};</li>
 *   <li>binary message: {@code DefaultWebSocketFrame} of type {@code 0xFF};</li>
 *   <li>ping: a {@code Pong} carrying the ping payload;</li>
 *   <li>pong and close: the frame is consumed and nothing is emitted.</li>
 * </ul>
 * Fragmented data messages are buffered and emitted as one message once the final
 * continuation frame has been received. Control frames may arrive between fragments.
 *
 * <p>The decoder keeps per-connection state in instance fields, so an instance must not
 * be shared between channels.
 */
public class Hybi10WebSocketFrameDecoder extends ReplayingDecoder<Hybi10WebSocketFrameDecoder.State> {

    private static final byte OPCODE_CONT = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;

    /** Opcode of the data message currently being reassembled, or {@code null} if none is in progress. */
    private Byte fragmentOpcode;

    /** Opcode of the frame currently being decoded; set for every frame in FRAME_START. */
    private byte opcode;

    /** Whether the frame currently being decoded has its FIN bit set. */
    private boolean finalFrame = true;

    /** Payload length in bytes of the frame currently being decoded. */
    private int currentFrameLength;

    /** The four-byte masking key of the frame currently being decoded. */
    private ChannelBuffer maskingKey;

    /** Unmasked payloads of the fragments received so far for the message in progress. */
    private final List<ChannelBuffer> frames = new ArrayList<ChannelBuffer>();

    /** Decoder states, one per part of a frame. */
    public static enum State {
        FRAME_START,
        PARSING_LENGTH,
        MASKING_KEY,
        PARSING_LENGTH_2,
        PARSING_LENGTH_3,
        PAYLOAD
    }

    /** Creates a decoder that starts out expecting the first byte of a frame. */
    public Hybi10WebSocketFrameDecoder() {
        super(State.FRAME_START);
    }

    /**
     * Decodes the next part of a frame, and the frame itself once its payload is complete.
     *
     * @param ctx     the handler context (unused)
     * @param channel the channel; its buffer factory is used to allocate the buffer for a
     *                reassembled fragmented message
     * @param buffer  the replaying input buffer
     * @param state   the part of the frame to decode next
     * @return the decoded message, or {@code null} when only part of a frame was consumed,
     *         when a non-final fragment was buffered, or when a pong/close frame was consumed
     * @throws CorruptedFrameException if reserved bits are set, the opcode is unknown or not
     *         allowed at this point of a fragmented message, the frame is unmasked, or the
     *         64-bit length does not fit into a non-negative {@code int}
     */
    @Override
    protected Object decode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, State state) throws Exception {
        switch (state) {
            case FRAME_START: {
                byte header = buffer.readByte();
                boolean fin = (header & 0x80) != 0;
                byte reserved = (byte) (header & 0x70);
                byte frameOpcode = (byte) (header & 0x0F);

                if (reserved != 0) {
                    throw new CorruptedFrameException("Reserved bits set: " + bits(reserved));
                }
                if (!isOpcode(frameOpcode)) {
                    throw new CorruptedFrameException("Invalid opcode " + hex(frameOpcode));
                }

                if (!fin) {
                    if (fragmentOpcode == null) {
                        // First fragment of a message: only data frames may be fragmented.
                        if (!isDataOpcode(frameOpcode)) {
                            throw new CorruptedFrameException("Fragmented frame with invalid opcode " + hex(frameOpcode));
                        }
                        fragmentOpcode = frameOpcode;
                    } else if (frameOpcode != OPCODE_CONT) {
                        throw new CorruptedFrameException("Continuation frame with invalid opcode " + hex(frameOpcode));
                    }
                } else {
                    if (fragmentOpcode != null) {
                        // While a message is being reassembled, a final frame is either its
                        // last continuation or an interleaved control frame.
                        if (!isControlOpcode(frameOpcode) && frameOpcode != OPCODE_CONT) {
                            throw new CorruptedFrameException("Final frame with invalid opcode " + hex(frameOpcode));
                        }
                    } else if (frameOpcode == OPCODE_CONT) {
                        throw new CorruptedFrameException("Final frame with invalid opcode " + hex(frameOpcode));
                    }
                }

                // Record opcode and FIN for every frame so the PAYLOAD state never acts on
                // values left over from a previous frame.
                this.opcode = frameOpcode;
                this.finalFrame = fin;
                checkpoint(State.PARSING_LENGTH);
            }
            // falls through
            case PARSING_LENGTH: {
                byte lengthByte = buffer.readByte();
                if ((lengthByte & 0x80) == 0) {
                    throw new CorruptedFrameException("Unmasked frame received");
                }

                int length = lengthByte & 0x7F;
                if (length < 126) {
                    currentFrameLength = length;
                    checkpoint(State.MASKING_KEY);
                } else if (length == 126) {
                    checkpoint(State.PARSING_LENGTH_2);
                } else {
                    checkpoint(State.PARSING_LENGTH_3);
                }
                return null;
            }
            case PARSING_LENGTH_2: {
                currentFrameLength = buffer.readUnsignedShort();
                checkpoint(State.MASKING_KEY);
                return null;
            }
            case PARSING_LENGTH_3: {
                long extendedLength = buffer.readLong();
                // The payload is read into a single buffer, so its length must fit into a
                // non-negative int; reject anything else instead of silently truncating.
                if (extendedLength < 0 || extendedLength > Integer.MAX_VALUE) {
                    throw new CorruptedFrameException("Unsupported frame length " + extendedLength);
                }
                currentFrameLength = (int) extendedLength;
                checkpoint(State.MASKING_KEY);
            }
            // falls through
            case MASKING_KEY: {
                maskingKey = buffer.readBytes(4);
                checkpoint(State.PAYLOAD);
            }
            // falls through
            case PAYLOAD: {
                ChannelBuffer frame = buffer.readBytes(currentFrameLength);
                // The whole frame has been consumed: whatever is returned below, the next
                // call starts at a new frame.
                checkpoint(State.FRAME_START);
                unmask(frame);

                if (isControlOpcode(opcode)) {
                    if (opcode == OPCODE_PING) {
                        return new Pong(0x00, frame);
                    }
                    // Pong and close frames are consumed without emitting a message.
                    return null;
                }

                if (!finalFrame) {
                    // Non-final fragment: keep the payload until the message is complete.
                    frames.add(frame);
                    return null;
                }

                byte messageOpcode = opcode;
                if (opcode == OPCODE_CONT) {
                    // Final continuation: FRAME_START guarantees fragmentOpcode is set here.
                    frames.add(frame);
                    frame = mergeFragments(channel);
                    messageOpcode = fragmentOpcode;
                    fragmentOpcode = null;
                }

                if (messageOpcode == OPCODE_TEXT) {
                    return new DefaultWebSocketFrame(0x00, frame);
                }
                return new DefaultWebSocketFrame(0xFF, frame);
            }
            default:
                throw new Error("Shouldn't reach here.");
        }
    }

    /**
     * Concatenates all buffered fragment payloads into one buffer and clears the fragment list.
     *
     * <p>The target buffer is allocated with the exact total size because a buffer obtained
     * from the channel's buffer factory is not guaranteed to grow on demand.
     *
     * @param channel the channel whose buffer factory allocates the result
     * @return a buffer holding the fragment payloads in arrival order
     */
    private ChannelBuffer mergeFragments(Channel channel) {
        int totalLength = 0;
        for (ChannelBuffer fragment : frames) {
            totalLength += fragment.readableBytes();
        }
        ChannelBuffer merged = channel.getConfig().getBufferFactory().getBuffer(totalLength);
        for (ChannelBuffer fragment : frames) {
            merged.writeBytes(fragment);
        }
        frames.clear();
        return merged;
    }

    /**
     * XORs the readable bytes of {@code frame} in place with the current masking key,
     * cycling through the four key bytes.
     *
     * <p>Works through the buffer's index-based accessors, so it does not require the
     * buffer to expose a backing array.
     *
     * @param frame the masked payload; unmasked on return
     */
    private void unmask(ChannelBuffer frame) {
        int start = frame.readerIndex();
        int end = frame.writerIndex();
        for (int i = start; i < end; i++) {
            int unmasked = frame.getByte(i) ^ maskingKey.getByte((i - start) % 4);
            frame.setByte(i, unmasked);
        }
    }

    /**
     * Formats a byte as exactly eight binary digits, most significant bit first.
     *
     * @param b the byte to format
     * @return the zero-padded binary representation of {@code b}
     */
    private String bits(byte b) {
        String binary = Integer.toBinaryString(b & 0xFF);
        StringBuilder padded = new StringBuilder(8);
        for (int i = binary.length(); i < 8; i++) {
            padded.append('0');
        }
        return padded.append(binary).toString();
    }

    /**
     * Formats a byte as unsigned lower-case hexadecimal without leading zeros.
     *
     * @param b the byte to format
     * @return the hexadecimal representation of {@code b}
     */
    private String hex(byte b) {
        return Integer.toHexString(b & 0xFF);
    }

    /** Returns whether {@code opcode} is one of the opcodes this decoder understands. */
    private boolean isOpcode(int opcode) {
        return opcode == OPCODE_CONT ||
                opcode == OPCODE_TEXT ||
                opcode == OPCODE_BINARY ||
                opcode == OPCODE_CLOSE ||
                opcode == OPCODE_PING ||
                opcode == OPCODE_PONG;
    }

    /** Returns whether {@code opcode} denotes a control frame (close, ping or pong). */
    private boolean isControlOpcode(int opcode) {
        return opcode == OPCODE_CLOSE ||
                opcode == OPCODE_PING ||
                opcode == OPCODE_PONG;
    }

    /** Returns whether {@code opcode} denotes a data frame (text or binary). */
    private boolean isDataOpcode(int opcode) {
        return opcode == OPCODE_TEXT ||
                opcode == OPCODE_BINARY;
    }
}