/***********************************************************************************************************************
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
**********************************************************************************************************************/
/**
* This file is based on source code from the Hadoop Project (http://hadoop.apache.org/), licensed by the Apache
* Software Foundation (ASF) under the Apache License, Version 2.0. See the NOTICE file distributed with this work for
* additional information regarding copyright ownership.
*/
package eu.stratosphere.nephele.ipc;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.util.HashMap;
import java.util.Map;

import javax.net.SocketFactory;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import eu.stratosphere.core.io.IOReadableWritable;
import eu.stratosphere.core.io.StringRecord;
import eu.stratosphere.core.protocols.VersionedProtocol;
import eu.stratosphere.nephele.net.NetUtils;
import eu.stratosphere.util.ClassUtils;
/**
 * A simple RPC mechanism.
 * A <i>protocol</i> is a Java interface extending {@link VersionedProtocol}. Every argument passed through a
 * protocol proxy must be an instance of {@link IOReadableWritable} or <code>null</code>; other argument types
 * (including boxed primitives and Strings) are rejected by the client-side invoker with an {@link IOException}.
 * Return types must be {@link IOReadableWritable} or <code>void</code>; any other return value makes the
 * server-side call fail with an {@link IOException}.
 * All methods in the protocol should throw only IOException. No field data of
 * the protocol instance is transmitted.
 */
public class RPC {

	private static final Log LOG = LogFactory.getLog(RPC.class);

	private RPC() {
	} // no public ctor

	/**
	 * A method invocation, including the method name and its parameters. This is the object that is written by the
	 * client and read by the server for every remote call.
	 */
	private static class Invocation implements IOReadableWritable {

		/** The name of the invoked method. */
		private String methodName;

		/** The declared parameter types of the invoked method; the server uses them to look the method up. */
		private Class<? extends IOReadableWritable>[] parameterClasses;

		/** The actual arguments of the call; individual entries may be <code>null</code>. */
		private IOReadableWritable[] parameters;

		/**
		 * Creates an empty invocation whose state is filled in by {@link #read(DataInput)}. The server is handed
		 * {@code Invocation.class} and presumably instantiates it reflectively through this constructor.
		 */
		@SuppressWarnings("unused")
		public Invocation() {
		}

		// TODO: See if type safety can be improved here
		/**
		 * Creates an invocation of the given method with the given arguments.
		 *
		 * @param method
		 *        the protocol method to invoke remotely
		 * @param parameters
		 *        the arguments of the call, or <code>null</code> if the method takes no arguments (which is how
		 *        {@link InvocationHandler#invoke} reports a no-argument call)
		 */
		@SuppressWarnings("unchecked")
		public Invocation(Method method, IOReadableWritable[] parameters) {
			this.methodName = method.getName();
			// Unchecked: the protocol contract requires IOReadableWritable parameter types
			this.parameterClasses = (Class<? extends IOReadableWritable>[]) method.getParameterTypes();
			// Normalize a missing argument array so that toString() and getParameters() never see null
			this.parameters = (parameters == null) ? new IOReadableWritable[0] : parameters;
		}

		/** The name of the method invoked. */
		public String getMethodName() {
			return this.methodName;
		}

		/** The parameter classes. */
		public Class<? extends IOReadableWritable>[] getParameterClasses() {
			return this.parameterClasses;
		}

		/** The parameter instances. */
		public IOReadableWritable[] getParameters() {
			return this.parameters;
		}

		// TODO: See if type safety can be improved here
		/**
		 * Reads the invocation in the format produced by {@link #write(DataOutput)}: the method name, the number of
		 * parameters and, for each parameter, the name of its declared class, a flag telling whether the argument is
		 * non-null and, if it is, the name of the argument's runtime class followed by the argument's own data.
		 *
		 * @param in
		 *        the input to read from
		 * @throws IOException
		 *         if reading fails, the parameter count is negative, or a class cannot be resolved or instantiated
		 */
		@SuppressWarnings("unchecked")
		public void read(DataInput in) throws IOException {
			this.methodName = StringRecord.readString(in);

			final int numParameters = in.readInt();
			if (numParameters < 0) {
				// Guard against corrupt input, which would otherwise surface as NegativeArraySizeException
				throw new IOException("Invalid number of parameters for method " + this.methodName + ": "
					+ numParameters);
			}

			this.parameters = new IOReadableWritable[numParameters];
			this.parameterClasses = new Class[numParameters];

			for (int i = 0; i < numParameters; i++) {
				// Declared parameter type, needed by the server to look up the method
				this.parameterClasses[i] = resolveClass(StringRecord.readString(in));

				// The flag tells whether the argument is non-null
				if (in.readBoolean()) {
					this.parameters[i] = instantiate(StringRecord.readString(in));
					// The argument deserializes its own state
					this.parameters[i].read(in);
				} else {
					this.parameters[i] = null;
				}
			}
		}

		/**
		 * Resolves a class name through {@link ClassUtils#getRecordByName(String)}, converting a
		 * {@link ClassNotFoundException} into an {@link IOException} that keeps the original as its cause.
		 */
		private static Class<? extends IOReadableWritable> resolveClass(String className) throws IOException {
			try {
				return ClassUtils.getRecordByName(className);
			} catch (ClassNotFoundException cnfe) {
				throw new IOException(cnfe.toString(), cnfe);
			}
		}

		/**
		 * Resolves the named class and creates an instance of it through its no-argument constructor, converting
		 * reflective failures into {@link IOException}s that keep the original as their cause.
		 */
		private static IOReadableWritable instantiate(String className) throws IOException {
			final Class<? extends IOReadableWritable> clazz = resolveClass(className);
			try {
				return clazz.newInstance();
			} catch (IllegalAccessException iae) {
				throw new IOException(iae.toString(), iae);
			} catch (InstantiationException ie) {
				throw new IOException(ie.toString(), ie);
			}
		}

		/**
		 * Writes the invocation in the format expected by {@link #read(DataInput)}.
		 *
		 * @param out
		 *        the output to write to
		 * @throws IOException
		 *         if writing fails
		 */
		public void write(DataOutput out) throws IOException {
			StringRecord.writeString(out, this.methodName);
			out.writeInt(this.parameterClasses.length);
			for (int i = 0; i < this.parameterClasses.length; i++) {
				StringRecord.writeString(out, this.parameterClasses[i].getName());
				final IOReadableWritable parameter = this.parameters[i];
				if (parameter == null) {
					out.writeBoolean(false);
				} else {
					out.writeBoolean(true);
					StringRecord.writeString(out, parameter.getClass().getName());
					parameter.write(out);
				}
			}
		}

		@Override
		public String toString() {
			final StringBuilder buffer = new StringBuilder();
			buffer.append(this.methodName);
			buffer.append('(');
			// parameters is null only for an instance created by the default constructor and not yet read
			if (this.parameters != null) {
				for (int i = 0; i < this.parameters.length; i++) {
					if (i != 0) {
						buffer.append(", ");
					}
					buffer.append(this.parameters[i]);
				}
			}
			buffer.append(')');
			return buffer.toString();
		}
	}

	/**
	 * Caches IPC clients using their socket factory as the key, so that proxies created with the same factory share
	 * one {@link Client}. Cached clients are reference counted and stopped once the last user releases them.
	 */
	private static class ClientCache {

		private final Map<SocketFactory, Client> clients = new HashMap<SocketFactory, Client>();

		/**
		 * Returns the IPC client cached for the given socket factory. If none exists, a new client is constructed
		 * and cached; otherwise the reference count of the cached client is incremented.
		 *
		 * @param factory
		 *        the socket factory the client uses to open connections
		 * @return an IPC client
		 */
		private synchronized Client getClient(SocketFactory factory) {
			Client client = this.clients.get(factory);
			if (client == null) {
				// NOTE(review): the count is not incremented here, so a new Client is assumed to start with a
				// reference count of one -- confirm against Client.
				client = new Client(factory);
				this.clients.put(factory, client);
			} else {
				client.incCount();
			}
			return client;
		}

		/**
		 * Releases one reference to an RPC client. The client is removed from the cache and stopped only when its
		 * reference count becomes zero.
		 *
		 * @param client
		 *        the client to release
		 */
		private void stopClient(Client client) {
			final boolean lastReference;
			synchronized (this) {
				client.decCount();
				// Decide inside the lock so that only the thread dropping the last reference stops the client;
				// re-checking outside the lock could let two concurrent releasers both call stop()
				lastReference = client.isZeroReference();
				if (lastReference) {
					this.clients.remove(client.getSocketFactory());
				}
			}
			// Stop outside the lock, presumably because stopping may block
			if (lastReference) {
				client.stop();
			}
		}
	}

	private static final ClientCache CLIENTS = new ClientCache();

	/**
	 * Invocation handler behind every protocol proxy: it packs the called method and its arguments into an
	 * {@link Invocation} and sends it through a shared, cached {@link Client} to the server address.
	 */
	private static class Invoker implements InvocationHandler {

		private final InetSocketAddress address;

		private final Client client;

		/** Guards against releasing the shared client more than once; accessed only under this object's lock. */
		private boolean isClosed = false;

		public Invoker(InetSocketAddress address, SocketFactory factory) {
			this.address = address;
			this.client = CLIENTS.getClient(factory);
		}

		/**
		 * Performs the remote call for the given proxy method.
		 *
		 * @throws IOException
		 *         if an argument is neither <code>null</code> nor an {@link IOReadableWritable}, or if the call fails
		 */
		public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
			// args is null for methods without parameters; Invocation handles that case
			IOReadableWritable[] castArgs = null;
			if (args != null) {
				castArgs = new IOReadableWritable[args.length];
				for (int i = 0; i < args.length; i++) {
					if (args[i] != null && !(args[i] instanceof IOReadableWritable)) {
						throw new IOException("Argument " + i + " of method " + method.getName()
							+ " is not of type IOReadableWriteable");
					}
					castArgs[i] = (IOReadableWritable) args[i];
				}
			}

			return this.client.call(new Invocation(method, castArgs), this.address, method.getDeclaringClass());
		}

		/** Releases the IPC client that's responsible for this invoker's RPCs; repeated calls have no effect. */
		private synchronized void close() {
			if (!this.isClosed) {
				this.isClosed = true;
				CLIENTS.stopClient(this.client);
			}
		}
	}

	/**
	 * Gets a proxy connection to a remote server, retrying without a time limit.
	 *
	 * @param protocol
	 *        protocol class
	 * @param addr
	 *        remote address
	 * @return the proxy
	 * @throws IOException
	 *         if the proxy cannot be created, or the thread is interrupted while waiting to retry
	 */
	public static <V extends VersionedProtocol> V waitForProxy(Class<V> protocol, InetSocketAddress addr)
			throws IOException {
		return waitForProxy(protocol, addr, Long.MAX_VALUE);
	}

	/**
	 * Gets a proxy connection to a remote server, retrying once per second on connection failures until the
	 * timeout has elapsed.
	 * <p>
	 * NOTE(review): {@link #getProxy(Class, InetSocketAddress)} only creates a local proxy and does not contact the
	 * server (the cached {@link Client} is created without the target address), so it is unclear whether the retry
	 * path below can ever be triggered -- confirm against Client.
	 *
	 * @param protocol
	 *        protocol class
	 * @param addr
	 *        remote address
	 * @param timeout
	 *        time in milliseconds before giving up
	 * @return the proxy
	 * @throws IOException
	 *         if the last connection attempt failed and the timeout has elapsed
	 * @throws InterruptedIOException
	 *         if the thread is interrupted while waiting to retry; the interrupt status is restored
	 */
	static <V extends VersionedProtocol> V waitForProxy(Class<V> protocol, InetSocketAddress addr,
			long timeout) throws IOException {
		final long startTime = System.currentTimeMillis();
		IOException ioe;
		while (true) {
			try {
				return getProxy(protocol, addr);
			} catch (ConnectException se) { // server has not been started yet
				LOG.info("Server at " + addr + " not available yet, Zzzzz...");
				ioe = se;
			} catch (SocketTimeoutException te) { // server is busy
				LOG.info("Problem connecting to server: " + addr);
				ioe = te;
			}

			// check if timed out (written so that timeout == Long.MAX_VALUE cannot overflow)
			if (System.currentTimeMillis() - timeout >= startTime) {
				throw ioe;
			}

			// wait for retry
			try {
				Thread.sleep(1000);
			} catch (InterruptedException ie) {
				// Restore the interrupt status and stop retrying; continuing would spin without sleeping
				Thread.currentThread().interrupt();
				final InterruptedIOException iioe = new InterruptedIOException(
					"Interrupted while waiting for proxy to " + addr);
				iioe.initCause(ioe);
				throw iioe;
			}
		}
	}

	/**
	 * Constructs a client-side proxy object that implements the named protocol, talking to a server at the named
	 * address.
	 *
	 * @param protocol
	 *        protocol class
	 * @param addr
	 *        remote address
	 * @param factory
	 *        socket factory used (and used as the cache key) for the underlying IPC client
	 * @return the proxy
	 * @throws IOException
	 *         declared for callers; this implementation only creates the proxy locally
	 */
	public static <V extends VersionedProtocol> V getProxy(Class<V> protocol, InetSocketAddress addr,
			SocketFactory factory) throws IOException {
		@SuppressWarnings("unchecked")
		final V proxy = (V) Proxy.newProxyInstance(protocol.getClassLoader(), new Class[] { protocol },
			new Invoker(addr, factory));
		return proxy;
	}

	/**
	 * Constructs a client-side proxy object with the default socket factory.
	 *
	 * @param protocol
	 *        protocol class
	 * @param addr
	 *        remote address
	 * @return the proxy
	 * @throws IOException
	 *         declared for callers; this implementation only creates the proxy locally
	 */
	public static <V extends VersionedProtocol> V getProxy(Class<V> protocol, InetSocketAddress addr)
			throws IOException {
		return getProxy(protocol, addr, NetUtils.getDefaultSocketFactory());
	}

	/**
	 * Stops this proxy and releases its invoker's resources.
	 *
	 * @param proxy
	 *        the proxy to be stopped; <code>null</code> is ignored. It must have been created by this class.
	 */
	public static void stopProxy(VersionedProtocol proxy) {
		if (proxy != null) {
			((Invoker) Proxy.getInvocationHandler(proxy)).close();
		}
	}

	/**
	 * Constructs a server for a protocol implementation instance listening on a port and address.
	 *
	 * @param instance
	 *        the instance whose methods will be called
	 * @param bindAddress
	 *        the address to bind on to listen for connection
	 * @param port
	 *        the port to listen for connections on
	 * @param numHandlers
	 *        the number of method handler threads to run
	 * @return the server
	 * @throws IOException
	 *         if the server cannot be created
	 */
	public static Server getServer(final Object instance, final String bindAddress, final int port,
			final int numHandlers) throws IOException {
		return new Server(instance, bindAddress, port, numHandlers);
	}

	/** An RPC Server. */
	public static class Server extends eu.stratosphere.nephele.ipc.Server {

		/** The object whose methods are invoked for incoming calls. */
		private final Object instance;

		/**
		 * Constructs an RPC server with a single handler thread.
		 *
		 * @param instance
		 *        the instance whose methods will be called
		 * @param bindAddress
		 *        the address to bind on to listen for connection
		 * @param port
		 *        the port to listen for connections on
		 * @throws IOException
		 *         if the server cannot be created
		 */
		public Server(Object instance, String bindAddress, int port)
				throws IOException {
			this(instance, bindAddress, port, 1);
		}

		/** Returns the simple part of a fully qualified class name, i.e. everything after the last dot. */
		private static String classNameBase(String className) {
			return className.substring(className.lastIndexOf('.') + 1);
		}

		/**
		 * Constructs an RPC server.
		 *
		 * @param instance
		 *        the instance whose methods will be called
		 * @param bindAddress
		 *        the address to bind on to listen for connection
		 * @param port
		 *        the port to listen for connections on
		 * @param numHandlers
		 *        the number of method handler threads to run
		 * @throws IOException
		 *         if the server cannot be created
		 */
		public Server(Object instance, String bindAddress, int port, int numHandlers) throws IOException {
			super(bindAddress, port, Invocation.class, numHandlers, classNameBase(instance.getClass().getName()));
			this.instance = instance;
		}

		/**
		 * Dispatches an incoming {@link Invocation} to the matching method of the protocol on the wrapped instance.
		 *
		 * @throws IOException
		 *         an {@link IOException} thrown by the invoked method is rethrown as is; any other failure
		 *         (including a non-{@link IOReadableWritable} return value) is wrapped in an {@link IOException}
		 *         that carries the original's message, stack trace and cause
		 */
		public IOReadableWritable call(Class<?> protocol, IOReadableWritable param, long receivedTime)
				throws IOException {
			try {
				final Invocation call = (Invocation) param;
				final Method method = protocol.getMethod(call.getMethodName(), call.getParameterClasses());
				method.setAccessible(true);
				final Object value = method.invoke(this.instance, (Object[]) call.getParameters());
				return (IOReadableWritable) value;
			} catch (InvocationTargetException e) {
				final Throwable target = e.getTargetException();
				if (target instanceof IOException) {
					throw (IOException) target;
				}
				final IOException ioe = new IOException(target.toString(), target);
				ioe.setStackTrace(target.getStackTrace());
				throw ioe;
			} catch (Throwable e) {
				final IOException ioe = new IOException(e.toString(), e);
				ioe.setStackTrace(e.getStackTrace());
				throw ioe;
			}
		}
	}
}