package edu.brown.protorpc;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;

import org.junit.Before;
import org.junit.Test;

import ca.evanjones.protorpc.Counter;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.CodedOutputStream;
import edu.brown.net.MockByteChannel;
import edu.brown.net.NonBlockingConnection;
/**
 * Unit tests for {@link ProtoConnection}, driven through a {@link MockByteChannel} so that
 * partial writes, fragmented reads and end-of-stream can be simulated deterministically.
 */
public class ProtoConnectionTest {
    MockByteChannel channel;
    NonBlockingConnection nonblock;
    ProtoConnection connection;

    @Before
    public void setUp() throws IOException {
        channel = new MockByteChannel();
        // Keep a handle on the wrapped connection so tests can reach it directly.
        nonblock = new NonBlockingConnection(null, channel);
        connection = new ProtoConnection(nonblock);
    }

    /**
     * Checks the framing written by {@code tryWrite}: a 4-byte little-endian length prefix
     * followed by the serialized message. Also checks that a write the channel only
     * partially accepts is finished later by {@code writeAvailable()}.
     */
    @Test
    public void testTryWrite() throws IOException {
        Counter.Value v = Counter.Value.newBuilder().setValue(42).build();

        // The channel accepts everything, so tryWrite is expected to report nothing pending.
        assertFalse(connection.tryWrite(v));
        byte[] firstWrite = channel.lastWrites.get(0);
        CodedInputStream in = CodedInputStream.newInstance(firstWrite);
        int length = in.readRawLittleEndian32();
        // The prefix must count exactly the bytes that follow the 4-byte length field.
        assertEquals(firstWrite.length - 4, length);
        Counter.Value w = Counter.Value.parseFrom(in);
        assertEquals(v, w);
        assertTrue(in.isAtEnd());

        // Accept only 3 bytes. The write cannot complete, so tryWrite is expected to return true.
        channel.clear();
        channel.numBytesToAccept = 3;
        assertTrue(connection.tryWrite(v));
        // NOTE(review): -1 presumably means "accept everything"; flushing should then leave
        // nothing pending.
        channel.numBytesToAccept = -1;
        assertFalse(connection.writeAvailable());
        // Since clear(), the channel should have seen the partial write plus the flush.
        assertEquals(2, channel.lastWrites.size());
    }

    /**
     * Delivers one framed message in two fragments. Checks that
     * {@code readBufferedMessage} succeeds only once the whole message has arrived, and
     * that closing the connection closes the underlying channel.
     */
    @Test
    public void testReadBufferedMessage() throws IOException {
        Counter.Value.Builder builder = Counter.Value.newBuilder();
        // Nothing has been received yet, so no complete message can be parsed.
        assertTrue(connection.readAllAvailable());
        assertFalse(connection.readBufferedMessage(builder));

        Counter.Value v = Counter.Value.newBuilder().setValue(42).build();
        byte[] all = makeConnectionMessage(v);
        // Split inside the 4-byte length prefix so even the header arrives incomplete.
        byte[] fragment1 = Arrays.copyOfRange(all, 0, 3);
        byte[] fragment2 = Arrays.copyOfRange(all, fragment1.length, all.length);

        channel.setNextRead(fragment1);
        assertTrue(connection.readAllAvailable());
        // A repeated read (presumably with no new data) must not disturb the partial header.
        assertTrue(connection.readAllAvailable());
        assertFalse(connection.readBufferedMessage(builder));

        channel.setNextRead(fragment2);
        assertTrue(connection.readAllAvailable());
        assertTrue(connection.readBufferedMessage(builder));
        assertEquals(v, builder.build());

        // Once the stream has ended, no further buffered message is expected.
        channel.end = true;
        assertFalse(connection.readBufferedMessage(builder));

        connection.close();
        assertTrue(channel.closed);
    }

    /**
     * Frames {@code value} with the layout verified in {@link #testTryWrite()}: a 4-byte
     * little-endian length prefix followed by the serialized message bytes.
     *
     * @param value message to frame
     * @return the length prefix followed by the serialized message
     * @throws IOException if serialization into the in-memory stream fails
     */
    private static byte[] makeConnectionMessage(Counter.Value value)
            throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        CodedOutputStream codedOutput = CodedOutputStream.newInstance(out);
        codedOutput.writeRawLittleEndian32(value.getSerializedSize());
        value.writeTo(codedOutput);
        // CodedOutputStream buffers internally, so flush before reading the bytes back.
        codedOutput.flush();
        return out.toByteArray();
    }

    /**
     * Reads many large messages through a single connection. The cumulative byte count
     * exceeds a per-stream size limit, so the test fails if the limit accumulates across
     * messages instead of being reset.
     */
    @Test
    public void testInputStreamLimitReset() throws IOException {
        final int MEGABYTE = 1 << 20;
        // NOTE(review): presumably mirrors protobuf's historical 64 MB CodedInputStream
        // default size limit (in megabytes); confirm against the protobuf version in use.
        final int CODED_INPUT_LIMIT = 64;

        // Build a ~1 MB string. Sending it CODED_INPUT_LIMIT * 2 times pushes roughly twice
        // the limit through the same connection.
        char[] megabyte = new char[MEGABYTE];
        Arrays.fill(megabyte, 'a');
        String megaString = new String(megabyte);
        Counter.Value megaValue = Counter.Value.newBuilder()
                .setName(megaString)
                .setValue(42)
                .build();
        byte[] all = makeConnectionMessage(megaValue);

        Counter.Value.Builder builder = Counter.Value.newBuilder();
        for (int i = 0; i < CODED_INPUT_LIMIT * 2; ++i) {
            channel.setNextRead(all);
            assertTrue(connection.readAllAvailable());
            assertTrue(connection.readBufferedMessage(builder));
            builder.clear();
        }
    }
}