// Copyright (c) 2009 Shardul Deo
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package com.googlecode.protobuf.socketrpc;
import java.util.HashMap;
import java.util.Map;
import com.google.protobuf.BlockingService;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.RpcCallback;
import com.google.protobuf.RpcController;
import com.google.protobuf.Service;
import com.google.protobuf.ServiceException;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Descriptors.ServiceDescriptor;
import com.googlecode.protobuf.socketrpc.SocketRpcProtos.ErrorReason;
import com.googlecode.protobuf.socketrpc.SocketRpcProtos.Request;
import com.googlecode.protobuf.socketrpc.SocketRpcProtos.Response;
import com.googlecode.protobuf.socketrpc.SocketRpcProtos.Response.Builder;
/**
 * Proxy that handles the RPC received by the server and forwards it to the
 * appropriate service.
 * <p>
 * Both the {@link #doRpc(Request, RpcCallback)} and
 * {@link #doBlockingRpc(Request)} methods try to find a matching
 * {@link BlockingService} first and a matching {@link Service} second.
 * Services are looked up by the full name of their service descriptor.
 * <p>
 * NOTE(review): the registries are plain {@link HashMap}s with no
 * synchronization; presumably all services are registered before requests
 * are dispatched from multiple threads -- confirm against the server code.
 *
 * @author Shardul Deo
 */
class RpcForwarder {

  /** Registered non-blocking services, keyed by full service name. */
  private final Map<String, Service> serviceMap =
      new HashMap<String, Service>();

  /** Registered blocking services, keyed by full service name. */
  private final Map<String, BlockingService> blockingServiceMap =
      new HashMap<String, BlockingService>();

  /**
   * Register an RPC service implementation to this forwarder. A service
   * registered earlier under the same full name is replaced.
   *
   * @param service Service implementation to register.
   */
  public void registerService(Service service) {
    serviceMap.put(service.getDescriptorForType().getFullName(), service);
  }

  /**
   * Register an RPC blocking service implementation to this forwarder. A
   * blocking service registered earlier under the same full name is replaced.
   *
   * @param service Blocking service implementation to register.
   */
  public void registerBlockingService(BlockingService service) {
    blockingServiceMap.put(service.getDescriptorForType().getFullName(),
        service);
  }

  /**
   * Handle the blocking RPC request by forwarding it to the correct
   * service/method.
   * <p>
   * When only a non-blocking {@link Service} matches, the response is built
   * as soon as the service's {@code callMethod} returns. A service that has
   * not invoked its callback by then yields a response whose callback flag
   * is {@code false} and that carries no response proto.
   *
   * @param rpcRequest Request naming the service and method to call.
   * @return Response built from the result of the method invocation.
   * @throws RpcException If there was some error executing the RPC.
   */
  public SocketRpcProtos.Response doBlockingRpc(
      SocketRpcProtos.Request rpcRequest) throws RpcException {
    // Get the service, first try BlockingService
    BlockingService blockingService = blockingServiceMap.get(
        rpcRequest.getServiceName());
    if (blockingService != null) {
      return forwardToBlockingService(rpcRequest, blockingService);
    }

    // Now try Service
    Service service = findService(rpcRequest);

    // Call service with a callback that only records what it is given
    Callback<Message> callback = new Callback<Message>();
    SocketRpcController socketController = new SocketRpcController();
    forwardToService(rpcRequest, callback, service, socketController);

    // Build and return response (callback invocation is optional)
    return createRpcResponse(callback.getResponse(), callback.isInvoked(),
        socketController);
  }

  /**
   * Handle the non-blocking RPC request by forwarding it to the correct
   * service/method.
   * <p>
   * If a {@link BlockingService} matches, it is called and
   * {@code rpcCallback} is run with its response before this method returns.
   * Otherwise {@code rpcCallback} is run whenever the matching
   * {@link Service} invokes the callback handed to it.
   *
   * @param rpcRequest Request naming the service and method to call.
   * @param rpcCallback Callback that receives the RPC response.
   * @throws RpcException If there was some error executing the RPC.
   */
  public void doRpc(SocketRpcProtos.Request rpcRequest,
      final RpcCallback<SocketRpcProtos.Response> rpcCallback)
      throws RpcException {
    // Get the service, first try BlockingService
    BlockingService blockingService = blockingServiceMap.get(
        rpcRequest.getServiceName());
    if (blockingService != null) {
      Response response = forwardToBlockingService(rpcRequest, blockingService);
      rpcCallback.run(response);
      return;
    }

    // Now try Service
    Service service = findService(rpcRequest);

    // Call service using wrapper around rpcCallback
    final SocketRpcController socketController = new SocketRpcController();
    RpcCallback<Message> callback = new RpcCallback<Message>() {
      @Override
      public void run(Message response) {
        rpcCallback.run(createRpcResponse(response, true, socketController));
      }
    };
    forwardToService(rpcRequest, callback, service, socketController);
  }

  /**
   * Look up the registered non-blocking service named by the request.
   *
   * @throws RpcException With {@link ErrorReason#SERVICE_NOT_FOUND} if no
   *         service is registered under the requested name.
   */
  private Service findService(SocketRpcProtos.Request rpcRequest)
      throws RpcException {
    Service service = serviceMap.get(rpcRequest.getServiceName());
    if (service == null) {
      throw new RpcException(ErrorReason.SERVICE_NOT_FOUND,
          "Could not find service: " + rpcRequest.getServiceName(), null);
    }
    return service;
  }

  /**
   * Invoke the requested method on the given blocking service and wrap its
   * result in an RPC response.
   *
   * @throws RpcException If the method is unknown, the request proto is
   *         invalid, the service throws a {@link ServiceException}
   *         (reported as {@link ErrorReason#RPC_FAILED}) or a
   *         {@link RuntimeException} (reported as
   *         {@link ErrorReason#RPC_ERROR}).
   */
  private Response forwardToBlockingService(Request rpcRequest,
      BlockingService blockingService) throws RpcException {
    // Get matching method
    MethodDescriptor method = getMethod(rpcRequest,
        blockingService.getDescriptorForType());

    // Create request for method
    Message request = getRequestProto(rpcRequest,
        blockingService.getRequestPrototype(method));

    // Call method
    SocketRpcController socketController = new SocketRpcController();
    try {
      Message response = blockingService.callBlockingMethod(method,
          socketController, request);
      return createRpcResponse(response, true, socketController);
    } catch (ServiceException e) {
      throw new RpcException(ErrorReason.RPC_FAILED, e.getMessage(), e);
    } catch (RuntimeException e) {
      throw new RpcException(ErrorReason.RPC_ERROR,
          "Error running method " + method.getFullName(), e);
    }
  }

  /**
   * Invoke the requested method on the given non-blocking service, handing
   * it the supplied callback and controller.
   *
   * @throws RpcException If the method is unknown, the request proto is
   *         invalid, or the service call throws a {@link RuntimeException}
   *         (reported as {@link ErrorReason#RPC_ERROR}).
   */
  private void forwardToService(SocketRpcProtos.Request rpcRequest,
      RpcCallback<Message> callback, Service service,
      RpcController socketController) throws RpcException {
    // Get matching method
    MethodDescriptor method = getMethod(rpcRequest,
        service.getDescriptorForType());

    // Create request for method
    Message request = getRequestProto(rpcRequest,
        service.getRequestPrototype(method));

    // Call method
    try {
      service.callMethod(method, socketController, request, callback);
    } catch (RuntimeException e) {
      throw new RpcException(ErrorReason.RPC_ERROR,
          "Error running method " + method.getFullName(), e);
    }
  }

  /**
   * Get matching method.
   *
   * @throws RpcException With {@link ErrorReason#METHOD_NOT_FOUND} if the
   *         service descriptor has no method with the requested name.
   */
  private MethodDescriptor getMethod(SocketRpcProtos.Request rpcRequest,
      ServiceDescriptor descriptor) throws RpcException {
    MethodDescriptor method = descriptor.findMethodByName(
        rpcRequest.getMethodName());
    if (method == null) {
      throw new RpcException(
          ErrorReason.METHOD_NOT_FOUND,
          String.format("Could not find method %s in service %s",
              rpcRequest.getMethodName(), descriptor.getFullName()),
          null);
    }
    return method;
  }

  /**
   * Get request protobuf for the RPC method.
   *
   * @param rpcRequest Request carrying the serialized method argument.
   * @param requestPrototype Prototype of the method's request message type.
   * @throws RpcException With {@link ErrorReason#BAD_REQUEST_PROTO} if the
   *         bytes cannot be parsed or the parsed message is not initialized.
   */
  private Message getRequestProto(SocketRpcProtos.Request rpcRequest,
      Message requestPrototype) throws RpcException {
    Message.Builder builder;
    try {
      builder = requestPrototype.newBuilderForType()
          .mergeFrom(rpcRequest.getRequestProto());
    } catch (InvalidProtocolBufferException e) {
      throw new RpcException(ErrorReason.BAD_REQUEST_PROTO,
          "Invalid request proto", e);
    }
    // Check before build() so an incomplete message is reported as an
    // RpcException
    if (!builder.isInitialized()) {
      throw new RpcException(ErrorReason.BAD_REQUEST_PROTO,
          "Invalid request proto", null);
    }
    return builder.build();
  }

  /**
   * Create RPC response protobuf from method invocation results.
   *
   * @param response Message produced by the service, or {@code null}.
   * @param callbackInvoked Whether the service invoked its callback; only
   *        consulted when {@code response} is {@code null}.
   * @param socketController Controller that was passed to the service; if it
   *        reports failure, the response carries its error text and
   *        {@link ErrorReason#RPC_FAILED}.
   */
  private SocketRpcProtos.Response createRpcResponse(Message response,
      boolean callbackInvoked, SocketRpcController socketController) {
    Builder responseBuilder = SocketRpcProtos.Response.newBuilder();
    if (response != null) {
      responseBuilder.setCallback(true).setResponseProto(
          response.toByteString());
    } else {
      // Set whether callback was called (in case of async)
      responseBuilder.setCallback(callbackInvoked);
    }
    if (socketController.failed()) {
      // Protobuf string setters reject null. Presumably errorText() can be
      // null when the service failed without giving a reason (depends on
      // SocketRpcController) -- fall back to an empty string so the failure
      // is still reported instead of raising a NullPointerException here.
      String errorText = socketController.errorText();
      responseBuilder.setError(errorText == null ? "" : errorText);
      responseBuilder.setErrorReason(ErrorReason.RPC_FAILED);
    }
    return responseBuilder.build();
  }

  /**
   * Callback that just saves the response and the fact that it was invoked.
   */
  static class Callback<T extends Message> implements RpcCallback<T> {

    private T response = null;
    private boolean invoked = false;

    @Override
    public void run(T response) {
      this.response = response;
      invoked = true;
    }

    /** Returns the last response passed to {@link #run}, or null if none. */
    public T getResponse() {
      return response;
    }

    /** Returns whether {@link #run} has been called. */
    public boolean isInvoked() {
      return invoked;
    }
  }

  /**
   * Signifies error while handling RPC.
   */
  static class RpcException extends Exception {

    /** Reason code describing the failure. */
    public final ErrorReason errorReason;

    /** Error message; the same string that is passed to the superclass. */
    public final String msg;

    /**
     * @param errorReason Reason code describing the failure.
     * @param msg Error message.
     * @param cause Underlying exception, or {@code null} if there is none.
     */
    public RpcException(ErrorReason errorReason, String msg, Exception cause) {
      super(msg, cause);
      this.errorReason = errorReason;
      this.msg = msg;
    }
  }
}