package cz.woitee.websockets;
import java.io.*;
import java.net.*;
import java.nio.charset.Charset;
import java.security.NoSuchAlgorithmException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import cz.woitee.websockets.streams.ServersideWebSocketInputStream;
import cz.woitee.websockets.streams.ServersideWebSocketOutputStream;
import cz.woitee.websockets.utils.UTF8String;
/**
* Server-side WebSocket endpoint. Listens on a port and returns accepted connections as WebSockets once the opening handshake has completed.
*
* @author woitee
*
*/
public class ServerWebSocket {
    // Charset used to decode header names and the request line. HTTP header syntax is ASCII,
    // which UTF-8 decodes identically, so this must stay UTF-8 or ASCII.
    protected static final String charsetName = "UTF-8";
    protected static final Charset charset = Charset.forName(charsetName);

    // Upper bound on the size of the opening-handshake request. The request comes from an
    // untrusted client, so reading must not continue forever if the blank line never arrives.
    private static final int MAX_HANDSHAKE_BYTES = 64 * 1024;

    // Value appended to the client's Sec-WebSocket-Key before hashing (RFC 6455).
    private static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    ServerSocket serverSock;

    /**
     * Opens the listening socket.
     *
     * @param port     port to listen on, passed to {@link ServerSocket}
     * @param backlog  requested maximum length of the pending-connection queue
     * @param bindAddr local address to bind to
     * @throws IOException if the server socket cannot be opened
     */
    public ServerWebSocket(int port, int backlog, InetAddress bindAddr) throws IOException {
        serverSock = new ServerSocket(port, backlog, bindAddr);
    }

    // Most recently accepted connection. accept() assigns it and initialHandshake() reads it
    // without any synchronization, so accept() must not be called from several threads at once.
    WebSocket socket;

    /**
     * Waits for a client connection, performs the WebSocket opening handshake and returns it.
     *
     * The returned WebSocket overrides getOutputStream and getInputStream to return custom
     * streams that create/parse WebSocket frames. This call blocks until the client's complete
     * handshake request (up to the blank line ending its headers) has been received. If the
     * handshake fails, the accepted connection is closed before the exception is rethrown.
     *
     * @return the connected WebSocket
     * @throws IOException        on socket errors or an oversized handshake request
     * @throws WebSocketException if the client's request is not a valid upgrade request
     */
    public WebSocket accept() throws IOException, WebSocketException {
        socket = new WebSocket(serverSock.accept()) {
            // NOTE(review): inside this anonymous class, `socket` binds to an accessible field
            // inherited from WebSocket if one exists; otherwise it is the enclosing
            // ServerWebSocket.socket, which is only assigned after this constructor finishes
            // (null on the first accept). Confirm the streams receive the new connection.
            private OutputStream outputStream = new ServersideWebSocketOutputStream(socket);
            private InputStream inputStream = new ServersideWebSocketInputStream(socket);

            @Override
            public OutputStream getOutputStream() throws IOException {
                return outputStream;
            }

            @Override
            public InputStream getInputStream() throws IOException {
                return inputStream;
            }
        };
        boolean handshakeDone = false;
        try {
            initialHandshake();
            handshakeDone = true;
        } finally {
            // The caller never receives the socket on failure, so it must be closed here.
            if (!handshakeDone) {
                closeQuietly(socket.getUnderlyingSocket());
            }
        }
        return socket;
    }

    /**
     * Reads the client's upgrade request from the current connection and writes the
     * "101 Switching Protocols" response to the raw socket.
     *
     * @throws IOException        on socket errors
     * @throws WebSocketException if the request is rejected by formResponse
     */
    protected void initialHandshake() throws IOException, WebSocketException {
        Socket rawSocket = socket.getUnderlyingSocket();
        RequestData request = readData(rawSocket);
        byte[] response = formResponse(request);
        OutputStream out = rawSocket.getOutputStream();
        out.write(response);
        out.flush();
    }

    /** Parsed opening-handshake request. */
    protected static class RequestData {
        // Header fields in arrival order: name -> raw value bytes with surrounding spaces/tabs
        // removed. The request line (e.g. "GET /chat HTTP/1.1") is stored whole under the null key.
        public LinkedHashMap<String, byte[]> map = new LinkedHashMap<String, byte[]>();
        // Every byte consumed from the socket while reading the request, line endings included.
        public byte[] completeData;
    }

    /**
     * Reads the HTTP upgrade request from {@code socket}, up to and including the blank line
     * that terminates the header section.
     *
     * Bytes are read one at a time so that nothing after the header section is consumed from
     * the stream. If the stream ends first, the partial result is returned (formResponse then
     * rejects it).
     *
     * @param socket connection to read from
     * @return the parsed request line and header fields plus the raw bytes read
     * @throws IOException on socket errors or if the request exceeds MAX_HANDSHAKE_BYTES
     */
    protected static RequestData readData(Socket socket) throws IOException {
        InputStream in = socket.getInputStream();
        RequestData ret = new RequestData();
        ByteArrayOutputStream allBytes = new ByteArrayOutputStream();
        ByteArrayOutputStream line = new ByteArrayOutputStream();
        int b;
        // Blocking reads: available() may be 0 simply because the request has not arrived yet.
        while ((b = in.read()) != -1) {
            allBytes.write(b);
            if (allBytes.size() > MAX_HANDSHAKE_BYTES) {
                throw new IOException("Handshake request exceeds " + MAX_HANDSHAKE_BYTES + " bytes.");
            }
            if (b == '\r') {
                continue; // accept both CRLF and bare LF line endings
            }
            if (b != '\n') {
                line.write(b);
                continue;
            }
            byte[] lineBytes = line.toByteArray();
            line.reset();
            if (lineBytes.length == 0) {
                if (ret.map.isEmpty()) {
                    continue; // tolerate empty lines before the request line (RFC 7230, 3.5)
                }
                break; // blank line: end of the header section
            }
            if (ret.map.isEmpty()) {
                // First non-empty line is the request line; keep it whole, even if it has colons.
                ret.map.put(null, lineBytes);
            } else {
                putHeaderLine(ret.map, lineBytes);
            }
        }
        ret.completeData = allBytes.toByteArray();
        return ret;
    }

    // Splits a "Name: value" line at its first colon and stores it in headers.
    // Lines without a colon are malformed and are ignored.
    private static void putHeaderLine(Map<String, byte[]> headers, byte[] line) {
        int colon = 0;
        while (colon < line.length && line[colon] != ':') {
            colon++;
        }
        if (colon == line.length) {
            return;
        }
        String name = new String(line, 0, colon, charset).trim();
        int start = colon + 1;
        int end = line.length;
        while (start < end && isOptionalWhitespace(line[start])) {
            start++;
        }
        while (end > start && isOptionalWhitespace(line[end - 1])) {
            end--;
        }
        byte[] value = new byte[end - start];
        System.arraycopy(line, start, value, 0, value.length);
        headers.put(name, value);
    }

    // HTTP "optional whitespace" (OWS) around header values: space or horizontal tab.
    private static boolean isOptionalWhitespace(byte b) {
        return b == ' ' || b == '\t';
    }

    // Looks up a header by name, ignoring case (HTTP field names are case-insensitive).
    // Returns null when the header is absent.
    private static byte[] findHeader(Map<String, byte[]> headers, String name) {
        byte[] exact = headers.get(name);
        if (exact != null) {
            return exact;
        }
        for (Map.Entry<String, byte[]> entry : headers.entrySet()) {
            if (entry.getKey() != null && entry.getKey().equalsIgnoreCase(name)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Builds the "101 Switching Protocols" response for a parsed upgrade request.
     *
     * @param req request as produced by readData; its first entry must be the request line
     * @return the response bytes, ready to be written to the socket
     * @throws WebSocketException if the request is invalid, or if UTF-8 or SHA-1 is unavailable
     */
    protected static byte[] formResponse(RequestData req) throws WebSocketException {
        Set<Map.Entry<String, byte[]>> entrySet = req.map.entrySet();
        byte[] secWebSocketKey = findHeader(req.map, "Sec-WebSocket-Key");
        // Filter some invalid requests
        if (entrySet.size() < 2 ||
                // first line must start with GET
                !new String(entrySet.iterator().next().getValue(), charset).startsWith("GET") ||
                secWebSocketKey == null || secWebSocketKey.length == 0) {
            throw new WebSocketException("Invalid request received.");
        }
        // Form actual response: Sec-WebSocket-Accept = base64(SHA-1(key + GUID))
        try {
            String acceptValue = new UTF8String(secWebSocketKey).concat(new UTF8String(WEBSOCKET_GUID))
                    .toSHA1().toBase64().toString();
            String message = "HTTP/1.1 101 Switching Protocols\r\n" +
                    "Upgrade: websocket\r\n" +
                    "Connection: Upgrade\r\n" +
                    "Sec-WebSocket-Accept: " +
                    acceptValue +
                    "\r\n" +
                    "\r\n";
            return new UTF8String(message).array();
        } catch (UnsupportedEncodingException e) {
            throw new WebSocketException("UTF-8 charset not available", e);
        } catch (NoSuchAlgorithmException e) {
            throw new WebSocketException("SHA-1 hashing not available.", e);
        }
    }

    // Closes a socket during error cleanup; a failure here must not mask the original exception.
    private static void closeQuietly(Socket s) {
        if (s == null) {
            return;
        }
        try {
            s.close();
        } catch (IOException ignored) {
            // Nothing useful to do: the connection is being abandoned anyway.
        }
    }

    /**
     * Closes the listening server socket (serverSock).
     *
     * @throws IOException if closing the server socket fails
     */
    public void close() throws IOException {
        serverSock.close();
    }
}