package com.peterhi.net.remote;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.SocketAddress;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import com.peterhi.InvocationAdapter;
import com.peterhi.Range;
import com.peterhi.Util;
import com.peterhi.io.PmInputStream;
import com.peterhi.io.PmOutputStream;
import com.peterhi.net.rudp.RudpEvent;
import com.peterhi.net.rudp.RudpListener;
import com.peterhi.net.rudp.RudpServer;
import com.peterhi.property.PropModel;
/**
 * Registry for remote objects exchanged over a {@link RudpServer}.
 * <p>
 * Local objects become reachable either by name ({@link #publish}) or by
 * object id ({@link #register}); objects living on a peer are obtained as
 * dynamic proxies via {@link #subscribe}. Every incoming packet is handed to a
 * worker pool, which answers subscribe and call requests addressed to this
 * registry. An object id is the {@code hashCode()} of the remote object.
 */
public final class RmRegistry implements RudpListener {

    /** Source of operation ids used to pair responses with their requests. */
    private static final AtomicInteger operationIds = new AtomicInteger();

    /** Context of the remote call currently being served on this thread. */
    private static final ThreadLocal<RmContext> context = new ThreadLocal<RmContext>();

    private final RudpServer server;
    private final List<RmEntry> entries;
    private final ScheduledExecutorService workers;
    /** Milliseconds to wait for a response; {@code 0} waits indefinitely. */
    private final int timeout;

    /**
     * Creates a registry on top of {@code server} and starts listening to it.
     *
     * @param server  transport used to send and receive remote-call messages
     * @param timeout milliseconds to wait for the response to an outgoing
     *                request; {@code 0} waits indefinitely
     * @throws IllegalArgumentException if {@code server} is null or
     *                                  {@code timeout} is negative
     */
    public RmRegistry(RudpServer server, int timeout) {
        if (server == null) {
            throw new IllegalArgumentException("Null server.");
        }
        if (timeout < 0) {
            throw new IllegalArgumentException("Negative timeout.");
        }
        this.server = server;
        this.entries = new ArrayList<RmEntry>();
        this.workers = Executors.newScheduledThreadPool(15);
        this.timeout = timeout;
        // Register last so no event can reach this instance before all of
        // its fields have been assigned.
        this.server.addRudpListener(this);
    }

    /**
     * Detaches this registry from its server and stops accepting new worker
     * tasks; already queued tasks still run. Work submitted afterwards (for
     * example a call through a proxy) is rejected by the executor.
     */
    public void shutdown() {
        server.removeRudpListener(this);
        workers.shutdown();
    }

    /**
     * Invocation handler behind the proxies created by this registry: forwards
     * each method call to the remote object {@code objectId} at
     * {@code address} and waits for its result.
     */
    final class Caller<T> extends InvocationAdapter<T> {
        private final SocketAddress address;
        private final int objectId;

        public Caller(Class<T> type, SocketAddress address, int objectId) {
            super(type);
            this.address = address;
            this.objectId = objectId;
        }

        /** Uses the remote object id as the proxy's hash code. */
        @Override
        protected Object invokeHashCode(Object[] args) throws Throwable {
            return objectId;
        }

        /**
         * Sends an {@link RmCallRequest} for {@code method} and returns the
         * decoded result.
         *
         * @throws RemoteException if no response arrives within the timeout or
         *                         the peer reports a failure
         */
        @Override
        protected Object invoke(Method method, Object[] args) throws Throwable {
            final RmCallRequest request = new RmCallRequest();
            request.setOperationId(nextOperationId());
            request.setObjectId(objectId);
            request.setName(method.getName());
            request.setParameters(encodeParams(args));
            RmCallResponse response = exchange(address, request,
                    new ResponseListener<RmCallResponse>(address) {
                        @Override
                        protected RmCallResponse match(PropModel model) {
                            if (model instanceof RmCallResponse) {
                                RmCallResponse r = (RmCallResponse) model;
                                if (r.getOperationId() == request.getOperationId()) {
                                    return r;
                                }
                            }
                            return null;
                        }
                    });
            if (response == null) {
                throw new RemoteException("No response.");
            }
            if (!response.isSucceeded()) {
                throw new RemoteException(response.getMessage());
            }
            Class<?>[] rtypes = new Class<?>[] { method.getReturnType() };
            Object[] rvalues = new Object[] { response.getValue() };
            return decodeParams(address, rtypes, rvalues)[0];
        }
    }

    /**
     * One-shot listener that captures the response to a single outgoing
     * request from a given peer and lets the requester block until it arrives.
     *
     * @param <R> the response type being awaited
     */
    private abstract class ResponseListener<R> implements RudpListener {
        private final SocketAddress peer;
        private R response;

        ResponseListener(SocketAddress peer) {
            this.peer = peer;
        }

        /**
         * Returns {@code model} as the awaited response, or {@code null} if it
         * belongs to some other exchange.
         */
        protected abstract R match(PropModel model) throws Exception;

        @Override
        public void received(final RudpEvent event) {
            workers.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        capture(event);
                    } catch (Exception ex) {
                        ex.printStackTrace();
                    }
                }
            });
        }

        /** Stores the first matching response from the peer and wakes waiters. */
        private void capture(RudpEvent event) throws Exception {
            if (!event.getSocketAddress().equals(peer)) {
                return;
            }
            R candidate = match(readModel(event));
            if (candidate == null) {
                return;
            }
            synchronized (this) {
                if (response == null) {
                    response = candidate;
                    notifyAll();
                }
            }
        }

        /**
         * Blocks until the response arrives or the registry timeout elapses.
         *
         * @return the response, or {@code null} on timeout
         */
        synchronized R await() throws InterruptedException {
            if (timeout == 0) {
                while (response == null) {
                    wait();
                }
                return response;
            }
            long deadline = System.nanoTime() + timeout * 1000000L;
            // Loop so a spurious wake-up cannot end the wait early.
            while (response == null) {
                long remaining = deadline - System.nanoTime();
                if (remaining <= 0) {
                    break;
                }
                // Round up so we never wake just short of the deadline.
                wait((remaining + 999999L) / 1000000L);
            }
            return response;
        }
    }

    /**
     * Sends {@code request} to {@code address} and waits for the response
     * accepted by {@code listener}. The listener is always removed again, even
     * if the wait is interrupted.
     *
     * @return the response, or {@code null} if none arrived within the timeout
     */
    private <R> R exchange(SocketAddress address, PropModel request,
            ResponseListener<R> listener) throws InterruptedException {
        // Listen before sending so a fast response cannot be missed.
        server.addRudpListener(listener);
        try {
            writeModel(address, request);
            return listener.await();
        } finally {
            server.removeRudpListener(listener);
        }
    }

    /** Returns a fresh operation id (unique within this JVM until it wraps). */
    private static int nextOperationId() {
        return operationIds.incrementAndGet();
    }

    /**
     * Publishes {@code remote} under {@code name} so peers can subscribe to it.
     *
     * @return {@code false} if an entry with that name already exists
     * @throws IllegalArgumentException if {@code name} is rejected by
     *                                  {@code Util.isEmpty}, or {@code type} or
     *                                  {@code remote} is null
     */
    public boolean publish(Class<?> type, String name, Remote remote)
            throws InterruptedException {
        synchronized (entries) {
            if (getEntry(name) != null) {
                return false;
            }
            entries.add(new RmEntry(type, name, remote));
            return true;
        }
    }

    /**
     * Asks the registry at {@code address} for the object published as
     * {@code name} and returns a proxy of type {@code type} that forwards calls
     * to it.
     *
     * @throws IllegalArgumentException if {@code address} is null or
     *                                  {@code name} is null or empty
     * @throws RemoteException          if no response arrives within the
     *                                  timeout or the peer reports a failure
     */
    public <T> T subscribe(final SocketAddress address, Class<T> type,
            final String name) throws InterruptedException, RemoteException {
        if (address == null) {
            throw new IllegalArgumentException("Null address.");
        }
        if (name == null || name.isEmpty()) {
            throw new IllegalArgumentException("Null or empty name.");
        }
        final RmSubscribeRequest request = new RmSubscribeRequest();
        request.setOperationId(nextOperationId());
        request.setName(name);
        RmSubscribeResponse response = exchange(address, request,
                new ResponseListener<RmSubscribeResponse>(address) {
                    @Override
                    protected RmSubscribeResponse match(PropModel model) {
                        if (model instanceof RmSubscribeResponse) {
                            RmSubscribeResponse r = (RmSubscribeResponse) model;
                            if (r.getOperationId() == request.getOperationId()) {
                                return r;
                            }
                        }
                        return null;
                    }
                });
        if (response == null) {
            throw new RemoteException("No response.");
        }
        if (!response.isSucceeded()) {
            throw new RemoteException(response.getMessage());
        }
        return createCaller(type, address, response.getObjectId());
    }

    /**
     * Registers {@code remote} by object id, or bumps its registration count
     * if an entry with that id already exists.
     *
     * @throws IllegalArgumentException if {@code remote} is null or its hash
     *                                  code is zero
     */
    public void register(Class<?> type, Remote remote) throws InterruptedException {
        if (remote == null) {
            throw new IllegalArgumentException("Null remote.");
        }
        synchronized (entries) {
            RmEntry entry = getEntry(remote.hashCode());
            if (entry == null) {
                entries.add(new RmEntry(type, null, remote));
            } else {
                entry.increment();
            }
        }
    }

    /**
     * Decrements the registration count of {@code remote}'s entry and drops
     * the entry once the count reaches zero. Unknown objects are ignored.
     *
     * @throws IllegalArgumentException if {@code remote} is null or its hash
     *                                  code is zero
     */
    public void unregister(Remote remote) throws InterruptedException {
        if (remote == null) {
            throw new IllegalArgumentException("Null remote.");
        }
        synchronized (entries) {
            RmEntry entry = getEntry(remote.hashCode());
            if (entry != null) {
                entry.decrement();
                if (entry.getCount() == 0) {
                    entries.remove(entry);
                }
            }
        }
    }

    /**
     * Looks up an entry by its published name.
     *
     * @return the first entry with that name, or {@code null} if there is none
     * @throws IllegalArgumentException if {@code name} is null or empty
     * @throws InterruptedException     never thrown by this implementation;
     *                                  declared for source compatibility
     */
    public RmEntry getEntry(final String name) throws InterruptedException {
        if (Util.isEmpty(name)) {
            throw new IllegalArgumentException("Null or empty name.");
        }
        // Scan sequentially under the lock. Handing sub-searches to `workers`
        // here would make a worker thread wait on other workers while holding
        // the lock, which can exhaust the pool and deadlock.
        synchronized (entries) {
            for (RmEntry entry : entries) {
                if (entry.getName().equals(name)) {
                    return entry;
                }
            }
            return null;
        }
    }

    /**
     * Looks up an entry by object id (the remote object's hash code).
     *
     * @return the first entry with that id, or {@code null} if there is none
     * @throws IllegalArgumentException if {@code id} is zero
     * @throws InterruptedException     never thrown by this implementation;
     *                                  declared for source compatibility
     */
    public RmEntry getEntry(final int id) throws InterruptedException {
        if (id == 0) {
            throw new IllegalArgumentException("Id cannot be zero.");
        }
        // Sequential for the same reason as getEntry(String).
        synchronized (entries) {
            for (RmEntry entry : entries) {
                if (entry.getObjectId() == id) {
                    return entry;
                }
            }
            return null;
        }
    }

    /** Hands every incoming packet to the worker pool for processing. */
    @Override
    public void received(final RudpEvent event) {
        workers.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    doReceive(event);
                } catch (Exception ex) {
                    ex.printStackTrace();
                }
            }
        });
    }

    /**
     * Returns the context of the remote call being served on the current
     * thread, or {@code null} when the thread is not serving a call.
     */
    public static RmContext getContext() {
        return context.get();
    }

    private void doReceive(RudpEvent event) throws Exception {
        SocketAddress address = event.getSocketAddress();
        PropModel model = readModel(event);
        processModel(address, model);
    }

    /**
     * Decodes the model carried by {@code event} and rewinds the event's
     * stream so that other listeners can decode it too.
     */
    private PropModel readModel(RudpEvent event) throws Exception {
        InputStream stream = event.getInputStream();
        // The registry and any pending response listeners may read the same
        // event from different worker threads; serialize mark/read/reset.
        synchronized (stream) {
            // A read limit of 0 permits the stream to drop the mark after the
            // first byte; allow the whole model to be re-read.
            stream.mark(Integer.MAX_VALUE);
            PmInputStream pis = new PmInputStream(stream);
            try {
                return pis.readModel();
            } finally {
                stream.reset();
                pis.close();
            }
        }
    }

    /** Dispatches requests; other models (e.g. responses) are ignored here. */
    private void processModel(SocketAddress address, PropModel model)
            throws Exception {
        if (model instanceof RmSubscribeRequest) {
            processSubscribeRequest(address, (RmSubscribeRequest) model);
        } else if (model instanceof RmCallRequest) {
            processCallRequest(address, (RmCallRequest) model);
        }
    }

    /** Answers a subscribe request with the object id of the named entry. */
    private void processSubscribeRequest(SocketAddress address,
            RmSubscribeRequest request) throws Exception {
        RmSubscribeResponse response = new RmSubscribeResponse();
        response.setOperationId(request.getOperationId());
        response.setMessage("Undefined error occurred.");
        response.setResult(RmResponse.ERROR_UNDEFINED);
        String name = request.getName();
        // getEntry(String) rejects empty names; the peer still deserves a reply.
        RmEntry entry = Util.isEmpty(name) ? null : getEntry(name);
        if (entry == null) {
            response.setMessage(MessageFormat.format(
                    "Remote named \"{0}\" not found.", name));
            response.setResult(RmResponse.ERROR_NOT_FOUND);
        } else {
            response.setObjectId(entry.getObjectId());
            response.setResult(RmResponse.SUCCESS);
            response.setMessage(null);
        }
        writeModel(address, response);
    }

    /** Answers a call request by invoking the method on the local object. */
    private void processCallRequest(SocketAddress address,
            RmCallRequest request) throws Exception {
        RmCallResponse response = new RmCallResponse();
        response.setOperationId(request.getOperationId());
        response.setMessage("Undefined error occurred.");
        response.setResult(RmResponse.ERROR_UNDEFINED);
        int objectId = request.getObjectId();
        // getEntry(int) rejects 0; treat it as "not found" so the peer gets a reply.
        RmEntry entry = (objectId == 0) ? null : getEntry(objectId);
        if (entry == null) {
            // String.valueOf avoids locale digit grouping ("1,234") in the id.
            response.setMessage(MessageFormat.format(
                    "Remote object with id \"{0}\" not found.",
                    String.valueOf(objectId)));
            response.setResult(RmResponse.ERROR_NOT_FOUND);
        } else {
            invokeLocal(address, entry.getRemote(), request, response);
        }
        writeModel(address, response);
    }

    /**
     * Invokes the method named in {@code request} on {@code remote} and
     * records the result or failure in {@code response}.
     */
    private void invokeLocal(SocketAddress address, Remote remote,
            RmCallRequest request, RmCallResponse response) throws Exception {
        String name = request.getName();
        Object[] params = request.getParameters();
        Method method = getMethod(remote, name, params);
        if (method == null) {
            response.setMessage(MessageFormat.format(
                    "Method \"{0}\" not found.", name));
            response.setResult(RmResponse.ERROR_NOT_FOUND);
            return;
        }
        try {
            params = decodeParams(address, method.getParameterTypes(), params);
            Object value;
            context.set(new RmContext(this, server.getSocketAddress(), address));
            try {
                value = method.invoke(remote, params);
            } finally {
                // Worker threads are pooled; never leak this call's context.
                context.remove();
            }
            response.setValue(encodeParams(new Object[] { value })[0]);
            response.setMessage(null);
            response.setResult(RmResponse.SUCCESS);
        } catch (InterruptedException ex) {
            throw ex;
        } catch (InvocationTargetException ex) {
            // Report what the method threw, not the reflection wrapper.
            Throwable cause = ex.getCause();
            response.setMessage((cause != null ? cause : ex).toString());
        } catch (Exception ex) {
            response.setMessage(ex.toString());
        }
    }

    /** Serializes {@code model} and sends it to {@code address} asynchronously. */
    private void writeModel(final SocketAddress address,
            final PropModel model) {
        if (address == null) {
            throw new IllegalArgumentException("Null address.");
        }
        if (model == null) {
            throw new IllegalArgumentException("Null model.");
        }
        workers.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    doWriteModel(address, model);
                } catch (Exception ex) {
                    ex.printStackTrace();
                }
            }
        });
    }

    private void doWriteModel(SocketAddress address, PropModel model)
            throws Exception {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        PmOutputStream pos = new PmOutputStream(baos);
        try {
            pos.writeModel(model);
        } finally {
            // Close before reading the bytes so anything PmOutputStream may
            // still buffer reaches baos.
            pos.close();
        }
        byte[] buffer = baos.toByteArray();
        server.send(address, buffer, 0, buffer.length, true);
    }

    /**
     * Finds a public method of {@code remote} by name and argument count
     * (overloads with the same arity are not distinguished).
     *
     * @return the method, or {@code null} if none matches
     */
    private Method getMethod(Remote remote, String name, Object[] args) {
        int arity = (args == null) ? 0 : args.length;
        for (Method method : remote.getClass().getMethods()) {
            if (method.getName().equals(name)
                    && method.getParameterTypes().length == arity) {
                return method;
            }
        }
        return null;
    }

    /**
     * Returns a copy of {@code params} with every {@link Remote} replaced by
     * its object id; {@code null} yields an empty array.
     */
    private Object[] encodeParams(Object[] params) {
        if (params == null) {
            return new Object[0];
        }
        Object[] encoded = params.clone();
        for (int i = 0; i < encoded.length; i++) {
            if (encoded[i] instanceof Remote) {
                encoded[i] = encoded[i].hashCode();
            }
        }
        return encoded;
    }

    /**
     * Returns a copy of {@code params} in which each value whose declared type
     * is a {@link Remote} is resolved from its object id: to the local object
     * if registered here, otherwise to a proxy for the object at
     * {@code address}. Null values are left as null.
     */
    private Object[] decodeParams(SocketAddress address, Class<?>[] ptypes,
            Object[] params) throws InterruptedException {
        if (params == null) {
            return new Object[0];
        }
        Object[] decoded = params.clone();
        for (int i = 0; i < decoded.length; i++) {
            // encodeParams sends null as-is, so there is no id to resolve.
            if (decoded[i] == null || !Remote.class.isAssignableFrom(ptypes[i])) {
                continue;
            }
            int objectId = (Integer) decoded[i];
            RmEntry entry = getEntry(objectId);
            if (entry == null) {
                decoded[i] = createCaller(ptypes[i], address, objectId);
            } else {
                decoded[i] = entry.getRemote();
            }
        }
        return decoded;
    }

    /** Creates a proxy of {@code type} forwarding to the given remote object. */
    private <T> T createCaller(Class<T> type, SocketAddress address,
            int objectId) throws InterruptedException {
        return new Caller<T>(type, address, objectId).createProxy();
    }
}
/**
 * Bookkeeping record for one locally exported remote object: its declared
 * type, its published name ({@code ""} when unnamed), the object itself and a
 * registration count.
 * <p>
 * The count is not synchronized here; {@code RmRegistry} changes it only while
 * holding its entry-list lock. Two entries are equal when both name and
 * remote object are equal.
 */
final class RmEntry {
    private final Class<?> type;
    private final String name;
    private final Remote remote;
    private int count;

    /**
     * Creates an entry with a registration count of one.
     *
     * @param type   declared type of the remote object; must not be null
     * @param name   published name, or {@code null} for an unnamed entry
     * @param remote the exported object; must not be null
     * @throws IllegalArgumentException if {@code type} or {@code remote} is null
     */
    public RmEntry(Class<?> type, String name, Remote remote) {
        if (type == null) {
            throw new IllegalArgumentException("Null type.");
        }
        if (remote == null) {
            throw new IllegalArgumentException("Null remote.");
        }
        this.type = type;
        this.name = (name != null) ? name : "";
        this.remote = remote;
        this.count = 1;
    }

    /** Returns the published name, or {@code ""} for an unnamed entry. */
    public String getName() {
        return name;
    }

    /** Returns the declared type given at construction. */
    public Class<?> getType() {
        return type;
    }

    /** Returns the exported object. */
    public Remote getRemote() {
        return remote;
    }

    /** Returns the object id, which is the exported object's hash code. */
    public int getObjectId() {
        return remote.hashCode();
    }

    /** Returns the current registration count. */
    public int getCount() {
        return count;
    }

    /** Adds one registration and returns the new count. */
    public int increment() {
        count += 1;
        return count;
    }

    /** Removes one registration and returns the new count. */
    public int decrement() {
        count -= 1;
        return count;
    }

    @Override
    public int hashCode() {
        // 31-multiplier combination over (name, remote); neither is ever null
        // once constructed.
        return 31 * (31 + name.hashCode()) + remote.hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        // The class is final, so instanceof admits exactly RmEntry instances.
        if (!(obj instanceof RmEntry)) {
            return false;
        }
        RmEntry that = (RmEntry) obj;
        return name.equals(that.name) && remote.equals(that.remote);
    }
}