package com.peterhi.remote;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.SocketAddress;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import com.peterhi.Run;
import com.peterhi.Util;
import com.peterhi.io.PmInputStream;
import com.peterhi.io.PmOutputStream;
import com.peterhi.net.rudp.RudpEvent;
import com.peterhi.net.rudp.RudpListener;
import com.peterhi.net.rudp.RudpServer;
import com.peterhi.property.PropException;
import com.peterhi.property.PropModel;
/**
 * Registry of objects that can be called across a {@link RudpServer} connection.
 * <p>
 * Local objects are made reachable by name ({@link #publish}) or anonymously
 * ({@link #register}); an object's {@code hashCode()} is used as its id on the wire.
 * Remote objects are obtained with {@link #subscribe}, which returns a dynamic proxy
 * whose calls are sent as {@code RemoteMethodInvocation} messages and answered with
 * {@code RemoteMethodReturn} messages.
 */
public final class RemoteRegistry implements RudpListener {

    /** Transport used for incoming requests and outgoing calls. */
    private final RudpServer server;
    /** Locally reachable objects; every access synchronizes on this list. */
    private final List<RegistryEntry> entries;
    /** Runs outgoing sends off the caller's thread. */
    private final ScheduledExecutorService threadPool;
    /**
     * Source of operation ids that pair responses with requests. A counter is used
     * because {@code hashCode()} values are not guaranteed to be unique.
     */
    private final AtomicInteger operationIds;
    /** Milliseconds a proxied call waits for its return value; 0 waits indefinitely. */
    private int timeout;

    /**
     * Creates a registry and starts listening on {@code server}.
     *
     * @param server  transport to listen on and send through
     * @param timeout milliseconds a proxied call waits for its reply (0 = no limit)
     */
    public RemoteRegistry(RudpServer server, int timeout) {
        this.server = server;
        this.entries = new ArrayList<RegistryEntry>();
        this.threadPool = Executors.newScheduledThreadPool(15);
        this.operationIds = new AtomicInteger();
        this.timeout = timeout;
        // Register last: once added, received() may run on another thread and must
        // see every field initialized.
        this.server.addRudpListener(this);
    }

    /**
     * Invocation handler behind client-side proxies. {@code hashCode}, {@code equals}
     * and {@code toString} are answered locally; every other call is forwarded to the
     * object {@code objectId} at {@code address}.
     */
    class Dispatcher implements InvocationHandler {
        private final Class<?> type;
        private final SocketAddress address;
        private final int objectId;

        public Dispatcher(Class<?> type, SocketAddress address, int objectId) {
            this.type = type;
            this.address = address;
            this.objectId = objectId;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args)
                throws Throwable {
            String mname = method.getName();
            Class<?> rtype = method.getReturnType();
            Class<?>[] ptypes = method.getParameterTypes();
            if (mname.equals("hashCode") && rtype == int.class &&
                    ptypes.length == 0) {
                return invokeHashCode(proxy, args);
            }
            if (mname.equals("equals") && rtype == boolean.class &&
                    ptypes.length == 1 && ptypes[0] == Object.class) {
                return invokeEquals(proxy, args);
            }
            if (mname.equals("toString") && rtype == String.class &&
                    ptypes.length == 0) {
                return invokeToString(proxy, args);
            }
            return invokeRemote(proxy, method, args);
        }

        /** Identity-based hash code of the proxy instance. */
        private Object invokeHashCode(Object proxy, Object[] args)
                throws Throwable {
            return System.identityHashCode(proxy);
        }

        /**
         * Identity equality. Because {@link #invokeHashCode} is identity-based, equality
         * must be too; comparing the hash codes, as before, gave false positives on
         * collisions, and casting a foreign handler to Dispatcher could throw.
         */
        private Object invokeEquals(Object proxy, Object[] args)
                throws Throwable {
            return proxy == args[0];
        }

        /** Renders {@code <interface name>@<hex hash>}. */
        private Object invokeToString(Object proxy, Object[] args)
                throws Throwable {
            int hashCode = proxy.hashCode();
            String pattern = "{0}@{1}";
            return MessageFormat.format(pattern, type.getName(),
                    Integer.toHexString(hashCode));
        }

        /**
         * Sends the call to the remote object and blocks for its return value.
         *
         * @throws RemoteException if no reply arrives in time or the remote call failed
         */
        private Object invokeRemote(Object proxy, Method method,
                Object[] args) throws Throwable {
            final RemoteMethodInvocation request = new RemoteMethodInvocation();
            request.setObjectId(objectId);
            request.setOperationId(operationIds.incrementAndGet());
            request.setName(method.getName());
            // The first declared parameter is the RemoteContext slot, which the receiver
            // fills in itself, so it is not transmitted. A proxy receives null args for
            // zero-parameter methods.
            Object[] params;
            if (args == null || args.length <= 1) {
                params = new Object[0];
            } else {
                params = Arrays.copyOfRange(args, 1, args.length);
            }
            request.setParameters(encode(params));

            final RemoteMethodReturn[] reference = new RemoteMethodReturn[1];
            RudpListener listener = new RudpListener() {
                @Override
                public void received(RudpEvent event) {
                    PropModel sink;
                    try {
                        sink = read(event);
                    } catch (Exception ex) {
                        ex.printStackTrace();
                        return;
                    }
                    if (!(sink instanceof RemoteMethodReturn)) {
                        return;
                    }
                    RemoteMethodReturn response = (RemoteMethodReturn) sink;
                    if (response.getOperationId() != request.getOperationId()) {
                        return;
                    }
                    synchronized (reference) {
                        reference[0] = response;
                        reference.notifyAll();
                    }
                }
            };

            RemoteMethodReturn response;
            server.addRudpListener(listener);
            try {
                write(address, request, true);
                response = awaitResponse(reference, timeout);
            } finally {
                server.removeRudpListener(listener);
            }

            if (response == null || !response.isSucceeded()) {
                String message = "Cannot get return value while calling " +
                        method + ".";
                if (response != null) {
                    message = response.getMessage();
                }
                throw new RemoteException(message);
            }
            return response.getValue();
        }
    }

    /**
     * Waits until a listener stores a response in {@code slot[0]} (notifying on
     * {@code slot}) or the timeout expires. Loops so spurious wake-ups do not end the
     * wait early. A timeout of 0 waits indefinitely, matching {@code Object.wait(0)};
     * a negative timeout does not wait. On interrupt, the interrupt status is restored
     * and the current value is returned.
     *
     * @return the response, or null if none arrived
     */
    private static <R> R awaitResponse(R[] slot, long timeoutMillis) {
        synchronized (slot) {
            boolean forever = timeoutMillis == 0L;
            long deadline = System.nanoTime() + timeoutMillis * 1000000L;
            while (slot[0] == null) {
                long waitMillis = 0L;
                if (!forever) {
                    waitMillis = (deadline - System.nanoTime()) / 1000000L;
                    if (waitMillis <= 0L) {
                        break;
                    }
                }
                try {
                    slot.wait(waitMillis);
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
            return slot[0];
        }
    }

    /**
     * Makes {@code remotable} reachable by remote peers under {@code name}.
     *
     * @throws IllegalArgumentException if {@code name} is already in use, or if
     *         {@code type} or {@code remotable} is null
     */
    public void publish(String name, Class<?> type, Remotable remotable) {
        // Check and insert under one lock so two callers cannot both claim a name.
        synchronized (entries) {
            if (getEntry(name) != null) {
                String pattern = "The name \"{0}\" is already used.";
                String message = MessageFormat.format(pattern, name);
                throw new IllegalArgumentException(message);
            }
            RegistryEntry entry = new RegistryEntry(name, type, remotable);
            if (!entries.contains(entry)) {
                entries.add(entry);
            }
        }
    }

    /**
     * Registers {@code remotable} anonymously so it can be passed by id in remote calls.
     *
     * @return the object already registered under the same id, or {@code remotable}
     */
    public <T extends Remotable> T register(Class<T> type, T remotable) {
        RegistryEntry entry;
        synchronized (entries) {
            entry = getEntry(remotable.hashCode());
            if (entry == null) {
                entry = new RegistryEntry(null, type, remotable);
                entries.add(entry);
            }
        }
        return type.cast(entry.getRemotable());
    }

    /**
     * Removes the entry whose id equals {@code remotable.hashCode()}.
     *
     * @return {@code remotable} if an entry was removed, otherwise null
     */
    public <T extends Remotable> T unregister(T remotable) {
        synchronized (entries) {
            RegistryEntry entry = getEntry(remotable.hashCode());
            if (entry == null) {
                return null;
            }
            entries.remove(entry);
            return remotable;
        }
    }

    /**
     * Looks up the object published as {@code name} at {@code address} and returns a
     * proxy implementing {@code type} that forwards calls to it.
     *
     * @param timeout milliseconds to wait for the lookup reply (0 = no limit)
     * @throws RemoteException if no reply arrives in time or the name is unknown
     */
    public <T extends Remotable> T subscribe(SocketAddress address,
            String name, Class<T> type, int timeout) throws RemoteException {
        final RemoteSubscriptionRequest request = new RemoteSubscriptionRequest();
        request.setOperationId(operationIds.incrementAndGet());
        request.setName(name);

        final RemoteSubscriptionResponse[] reference =
                new RemoteSubscriptionResponse[1];
        RudpListener listener = new RudpListener() {
            @Override
            public void received(RudpEvent event) {
                PropModel sink;
                try {
                    sink = read(event);
                } catch (Exception ex) {
                    ex.printStackTrace();
                    return;
                }
                if (!(sink instanceof RemoteSubscriptionResponse)) {
                    return;
                }
                RemoteSubscriptionResponse response =
                        (RemoteSubscriptionResponse) sink;
                if (response.getOperationId() != request.getOperationId()) {
                    return;
                }
                synchronized (reference) {
                    reference[0] = response;
                    reference.notifyAll();
                }
            }
        };

        RemoteSubscriptionResponse response;
        server.addRudpListener(listener);
        try {
            write(address, request, true);
            response = awaitResponse(reference, timeout);
        } finally {
            server.removeRudpListener(listener);
        }

        if (response == null || !response.isSucceeded()) {
            String message = "Not object returned.";
            if (response != null) {
                message = response.getMessage();
            }
            throw new RemoteException(message);
        }
        return type.cast(createProxy(type, address, response.getObjectId()));
    }

    /**
     * Returns the first entry published under {@code name}, or null. Anonymous entries
     * are stored with the name "".
     */
    public RegistryEntry getEntry(String name) {
        // A plain scan under the (reentrant) lock. The previous thread-pool fan-out
        // deadlocked when called while this thread already held the lock, as
        // register() does.
        synchronized (entries) {
            for (RegistryEntry entry : entries) {
                if (entry.getName().equals(name)) {
                    return entry;
                }
            }
        }
        return null;
    }

    /** Returns the first entry whose object id equals {@code objectId}, or null. */
    public RegistryEntry getEntry(int objectId) {
        synchronized (entries) {
            for (RegistryEntry entry : entries) {
                if (entry.getObjectId() == objectId) {
                    return entry;
                }
            }
        }
        return null;
    }

    /**
     * Handles incoming subscription requests and method invocations; any other message
     * type is ignored here (replies are consumed by per-call listeners).
     */
    @Override
    public void received(final RudpEvent event) {
        PropModel sink;
        try {
            sink = read(event);
        } catch (Exception ex) {
            ex.printStackTrace();
            return;
        }
        if (sink instanceof RemoteSubscriptionRequest) {
            handleSubscription(event.getSocketAddress(),
                    (RemoteSubscriptionRequest) sink);
        } else if (sink instanceof RemoteMethodInvocation) {
            handleInvocation(event.getSocketAddress(),
                    (RemoteMethodInvocation) sink);
        }
    }

    /** Replies to a name lookup with the object id, or with ERROR_NOT_FOUND. */
    private void handleSubscription(SocketAddress sender,
            RemoteSubscriptionRequest request) {
        RemoteSubscriptionResponse response = new RemoteSubscriptionResponse();
        response.setOperationId(request.getOperationId());
        response.setResult(RemoteResponse.ERROR_NOT_FOUND);
        response.setMessage("Object not found.");
        RegistryEntry entry = getEntry(request.getName());
        if (entry != null) {
            response.setResult(RemoteResponse.SUCCESS);
            response.setMessage(null);
            response.setObjectId(entry.getObjectId());
        }
        write(sender, response, true);
    }

    /** Invokes the requested method on a local object and replies with the outcome. */
    private void handleInvocation(SocketAddress sender,
            RemoteMethodInvocation request) {
        RemoteMethodReturn response = new RemoteMethodReturn();
        response.setOperationId(request.getOperationId());
        response.setResult(RemoteResponse.ERROR_NOT_FOUND);
        response.setMessage("Object " + request.getObjectId() + ", or method " +
                request.getName() + " not found.");
        RegistryEntry entry = getEntry(request.getObjectId());
        if (entry != null) {
            Object[] params = request.getParameters();
            if (params == null) {
                params = new Object[0];
            }
            // Slot 0 carries the caller's context; the transmitted arguments follow.
            Object[] actualParams = new Object[params.length + 1];
            actualParams[0] = new RemoteContext(sender);
            System.arraycopy(params, 0, actualParams, 1, params.length);
            Method method = findMethod(entry, request.getName(), actualParams);
            if (method != null) {
                try {
                    actualParams = decode(sender, method.getParameterTypes(),
                            actualParams, 8000);
                    Object ret = method.invoke(entry.getRemotable(), actualParams);
                    response.setResult(RemoteResponse.SUCCESS);
                    response.setMessage(null);
                    response.setValue(ret);
                } catch (InvocationTargetException ex) {
                    // Report what the target method threw, not the reflective wrapper.
                    Throwable cause = ex.getCause() != null ? ex.getCause() : ex;
                    response.setResult(RemoteResponse.ERROR_UNKNOWN);
                    response.setMessage(cause.toString());
                } catch (Exception ex) {
                    response.setResult(RemoteResponse.ERROR_UNKNOWN);
                    response.setMessage(ex.toString());
                }
            }
        }
        write(sender, response, true);
    }

    /** Serializes {@code sink} and sends it to {@code address} on the thread pool. */
    private void write(final SocketAddress address, final PropModel sink,
            final boolean reliable) {
        threadPool.submit(new Run<Void>() {
            @Override
            protected void onRun() throws Exception {
                ByteArrayOutputStream buffer = new ByteArrayOutputStream();
                PmOutputStream out = new PmOutputStream(buffer);
                try {
                    out.writeModel(sink);
                } finally {
                    // Close before reading the buffer so anything the wrapper may still
                    // hold is written through.
                    out.close();
                }
                byte[] bytes = buffer.toByteArray();
                server.send(address, bytes, 0, bytes.length, reliable);
            }
        }.asRunnable());
    }

    /**
     * Parses one model from the event's stream, then rewinds the stream so other
     * listeners handling the same event can parse it as well (assumes
     * {@code getInputStream()} returns a shared, mark-capable stream — as the original
     * mark/reset implies).
     */
    private PropModel read(RudpEvent event) throws IOException,
            ClassNotFoundException, PropException, IllegalAccessException,
            InstantiationException {
        InputStream stream = event.getInputStream();
        // Mark before wrapping so anything the wrapper reads on construction is rewound
        // too, with a read limit that cannot be exceeded by the model.
        stream.mark(Integer.MAX_VALUE);
        PmInputStream pis = new PmInputStream(stream);
        try {
            return pis.readModel();
        } finally {
            // Rewind even when parsing fails so the next listener starts at the beginning.
            stream.reset();
            pis.close();
        }
    }

    /**
     * Returns the first public method of the remotable's class with the given name and
     * parameter count. Parameter types are not compared, so same-arity overloads are
     * ambiguous.
     */
    private Method findMethod(RegistryEntry entry, String name,
            Object[] params) {
        Class<?> type = entry.getRemotable().getClass();
        for (Method method : type.getMethods()) {
            if (method.getName().equals(name) &&
                    method.getParameterTypes().length == params.length) {
                return method;
            }
        }
        return null;
    }

    /** Replaces each Remotable argument in place with its object id (hashCode). */
    private Object[] encode(Object[] args) {
        for (int i = 0; i < args.length; i++) {
            if (args[i] instanceof Remotable) {
                Remotable remotable = (Remotable) args[i];
                args[i] = remotable.hashCode();
            }
        }
        return args;
    }

    /**
     * Replaces, in place, each Integer argument whose declared type is Remotable with the
     * local object of that id, or with a proxy to {@code address} if none is registered.
     * The {@code timeout} parameter is currently unused.
     */
    private Object[] decode(SocketAddress address, Class<?>[] types,
            Object[] args, int timeout) {
        for (int i = 0; i < types.length; i++) {
            if (Remotable.class.isAssignableFrom(types[i]) &&
                    args[i] instanceof Integer) {
                int objectId = (Integer) args[i];
                RegistryEntry entry = getEntry(objectId);
                if (entry != null) {
                    args[i] = entry.getRemotable();
                } else {
                    args[i] = createProxy(types[i], address, objectId);
                }
            }
        }
        return args;
    }

    /** Builds a proxy implementing {@code type} that forwards to {@code objectId}. */
    private Object createProxy(Class<?> type, SocketAddress address, int objectId) {
        return Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[] { type },
                new Dispatcher(type, address, objectId));
    }
}
/**
 * Immutable description of one locally reachable object: its public name ("" when
 * anonymous), the type it was registered under, and the object itself. The object's
 * {@code hashCode()} doubles as its id.
 */
final class RegistryEntry {

    private final String name;
    private final Class<?> type;
    private final Remotable remotable;

    /**
     * @param name      public name, or null for an anonymous entry (stored as "")
     * @param type      registered type; must not be null
     * @param remotable the object itself; must not be null
     * @throws IllegalArgumentException if {@code type} or {@code remotable} is null
     */
    public RegistryEntry(String name, Class<?> type, Remotable remotable) {
        if (type == null) {
            throw new IllegalArgumentException("The type is null.");
        }
        if (remotable == null) {
            throw new IllegalArgumentException("The object is null.");
        }
        this.name = (name != null) ? name : "";
        this.type = type;
        this.remotable = remotable;
    }

    public String getName() {
        return name;
    }

    public Class<?> getType() {
        return type;
    }

    public Remotable getRemotable() {
        return remotable;
    }

    /** The id used on the wire: the wrapped object's hash code. */
    public int getObjectId() {
        return remotable.hashCode();
    }

    /** Combines object, name and type hashes; all fields are non-null by construction. */
    @Override
    public int hashCode() {
        int hash = 31 + remotable.hashCode();
        hash = 31 * hash + name.hashCode();
        return 31 * hash + type.hashCode();
    }

    /** Entries are equal when object, name and type are all equal. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        RegistryEntry that = (RegistryEntry) obj;
        return remotable.equals(that.remotable)
                && name.equals(that.name)
                && type.equals(that.type);
    }
}