package pl.icedev.rmi;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.json.simple.parser.ParseException;
import pl.icedev.nio.NIOConnection;
import pl.icedev.rmi.RMIServer.RMIInterface;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.util.HashMap;
import java.util.Map;
/**
 * Server-side connection that speaks a small JSON RMI protocol on top of {@link NIOConnection}.
 *
 * <p>Incoming messages are JSON objects of one of two shapes:
 * <ul>
 *   <li>{@code {"request": name}} — look up an interface registered on the server; the reply is
 *       {@code {"interface": [method names], "id": n}} where {@code n} is a connection-local id.</li>
 *   <li>{@code {"id": n, "method": name, "args": [...]}} — invoke a method on the interface bound
 *       to {@code n}; the reply is {@code {"result": ...}} or an {@code "error"} object.</li>
 * </ul>
 * Anything else is answered with {@link RMIFault#INVALID_MESSAGE}.
 */
public class RMIConnection extends NIOConnection {
    private final JSONParser parser;
    private final RMIServer server;

    /**
     * Creates a connection that serves the interfaces registered on {@code server}.
     *
     * @param server   server whose {@code interfaces} this connection exposes
     * @param channel  socket channel of the connected client
     * @param selector selector passed through to {@link NIOConnection}
     * @throws IOException propagated from the {@link NIOConnection} constructor
     */
    public RMIConnection(RMIServer server, SocketChannel channel, Selector selector)
            throws IOException {
        super(channel, selector);
        this.server = server;
        this.parser = new JSONParser();
    }

    /** Replies with {@code {"result": result}}. */
    @SuppressWarnings("unchecked") // json-simple's JSONObject is a raw Map
    private void sendResult(Object result) {
        JSONObject json = new JSONObject();
        json.put("result", result);
        writePacket(json.toJSONString());
    }

    /** Replies with {@code {"error": fault}}. */
    @SuppressWarnings("unchecked") // json-simple's JSONObject is a raw Map
    private void sendError(RMIFault fault) {
        JSONObject json = new JSONObject();
        json.put("error", fault.toString());
        writePacket(json.toJSONString());
    }

    /**
     * Replies with an {@link RMIFault#EXCEPTION} error carrying the exception's message and
     * class name.
     */
    @SuppressWarnings("unchecked") // json-simple's JSONObject is a raw Map
    private void sendException(Exception e) {
        JSONObject json = new JSONObject();
        json.put("error", RMIFault.EXCEPTION.toString());
        json.put("message", e.getMessage());
        // getCanonicalName() is null for anonymous and local classes; fall back to the binary name.
        String type = e.getClass().getCanonicalName();
        json.put("exception", type != null ? type : e.getClass().getName());
        // TODO: include the stack trace and the cause chain
        writePacket(json.toJSONString());
    }

    /** Next connection-local id to assign to a newly requested interface. */
    int idx;
    /** Connection-local id -> interface, used to dispatch method calls. */
    Map<Integer, RMIInterface> int2itf = new HashMap<>();
    /** Interface -> connection-local id, so repeated requests reuse the same id. */
    Map<RMIInterface, Integer> itf2int = new HashMap<>();

    /**
     * Dispatches one decoded message by its discriminating key: {@code "id"} is a method call,
     * {@code "request"} is an interface lookup, anything else is invalid.
     */
    private void handle(JSONObject json) {
        if (json.containsKey("id")) {
            handleInvoke(json);
        } else if (json.containsKey("request")) {
            handleRequest(json);
        } else {
            sendError(RMIFault.INVALID_MESSAGE);
        }
    }

    /** Handles {@code {"id", "method", "args"}}: invokes the method and replies with the outcome. */
    private void handleInvoke(JSONObject json) {
        int id = TypeUtil.toInt(json.get("id"));
        Object method = json.get("method");
        Object args = json.get("args");
        // Validate before casting: a bad type here used to throw out of handle() unreported,
        // and a missing "args" used to surface as an INTERNAL_ERROR via a NullPointerException.
        if (!(method instanceof String) || !(args instanceof JSONArray)) {
            sendError(RMIFault.INVALID_MESSAGE);
            return;
        }
        RMIInterface interf = int2itf.get(id);
        if (interf == null) {
            sendError(RMIFault.ITF_ID_NOT_FOUND);
            return;
        }
        try {
            sendResult(interf.invoke((String) method, ((JSONArray) args).toArray()));
        } catch (InvocationTargetException e) {
            // The invoked method itself threw: forward ordinary exceptions to the client,
            // but treat Errors (and other non-Exception throwables) as server failures.
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                sendException((Exception) cause);
            } else {
                e.printStackTrace();
                sendError(RMIFault.INTERNAL_ERROR);
            }
        } catch (Exception e) {
            e.printStackTrace();
            sendError(RMIFault.INTERNAL_ERROR);
        }
    }

    /**
     * Handles {@code {"request": name}}: binds the named server interface to a connection-local
     * id (reusing an existing one) and replies with its method names and that id.
     */
    @SuppressWarnings("unchecked") // json-simple's JSONObject/JSONArray are raw collections
    private void handleRequest(JSONObject json) {
        Object requested = json.get("request");
        if (!(requested instanceof String)) {
            sendError(RMIFault.INVALID_MESSAGE);
            return;
        }
        String itname = (String) requested;
        RMIInterface interf = server.interfaces.get(itname);
        if (interf == null) {
            sendError(RMIFault.ITF_NOT_IMPLEMENTED);
            // Diagnostic output: list the interfaces the server does offer.
            System.out.println("Got interfaces: ");
            for (String s : server.interfaces.keySet()) {
                System.out.print(s + ",");
            }
            System.out.println();
            return;
        }
        Integer num = itf2int.get(interf);
        if (num == null) {
            num = idx++;
            itf2int.put(interf, num);
            int2itf.put(num, interf);
        }
        // A JSONArray serializes as a JSON list in every json-simple version; a raw Object[]
        // is only handled by newer releases.
        JSONArray methodNames = new JSONArray();
        methodNames.addAll(interf.methods.keySet());
        JSONObject obj = new JSONObject();
        obj.put("interface", methodNames);
        obj.put("id", num);
        writePacket(obj.toJSONString());
    }

    /**
     * Parses one incoming packet and handles it. Unparseable JSON closes the connection; valid
     * JSON that is not an object gets {@link RMIFault#INVALID_MESSAGE}.
     */
    @Override
    public void onMessage(String msg) {
        Object parsed;
        try {
            parsed = parser.parse(msg);
        } catch (ParseException e) {
            e.printStackTrace();
            close();
            return;
        }
        if (!(parsed instanceof JSONObject)) {
            // e.g. a bare array or number; previously a ClassCastException escaped this callback.
            sendError(RMIFault.INVALID_MESSAGE);
            return;
        }
        try {
            handle((JSONObject) parsed);
        } catch (RuntimeException e) {
            // Boundary: a failure while handling one request (e.g. an unusable "id") is reported
            // to the client instead of propagating into the NIO layer.
            e.printStackTrace();
            sendError(RMIFault.INTERNAL_ERROR);
        }
    }

    /** No per-connection cleanup is performed. */
    @Override
    public void onDisconnect() {
    }
}