/*
* Copyright (c) 1998-2011 Caucho Technology -- all rights reserved
*
* This file is part of Resin(R) Open Source
*
* Each copy or derived work must preserve the copyright notice and this
* notice unmodified.
*
* Resin Open Source is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* Resin Open Source is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE, or any warranty
* of NON-INFRINGEMENT. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License
* along with Resin Open Source; if not, write to the
*
* Free Software Foundation, Inc.
* 59 Temple Place, Suite 330
* Boston, MA 02111-1307 USA
*
* @author Emil Ong
*/
package com.caucho.soa.rest;
import com.caucho.server.util.CauchoSystem;
import com.caucho.soa.servlet.ProtocolServlet;
import com.caucho.vfs.Vfs;
import com.caucho.vfs.WriteStream;
import com.caucho.xml.stream.XMLStreamWriterImpl;

import javax.jws.WebMethod;
import javax.jws.WebParam;
import javax.jws.WebService;
import javax.servlet.GenericServlet;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.namespace.QName;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
/**
* A binding for REST services.
*/
public abstract class RestProtocolServlet extends GenericServlet
implements ProtocolServlet
{
private static final Logger log =
Logger.getLogger(RestProtocolServlet.class.getName());
public static final String DELETE = "DELETE";
public static final String GET = "GET";
public static final String HEAD = "HEAD";
public static final String POST = "POST";
public static final String PUT = "PUT";
private Map<String,Map<String,Method>> _methods
= new HashMap<String,Map<String,Method>>();
private HashMap<String,Method> _defaultMethods
= new HashMap<String,Method>();
protected Object _service;
public RestProtocolServlet()
{
}
public void setService(Object service)
{
_service = service;
}
public void init()
throws ServletException
{
try {
Class cl = _service.getClass();
if (cl.isAnnotationPresent(WebService.class)) {
WebService webService
= (WebService) cl.getAnnotation(WebService.class);
String endpoint = webService.endpointInterface();
if (endpoint != null && ! "".equals(endpoint))
cl = CauchoSystem.loadClass(webService.endpointInterface());
}
_methods.put(DELETE, new HashMap<String,Method>());
_methods.put(GET, new HashMap<String,Method>());
_methods.put(HEAD, new HashMap<String,Method>());
_methods.put(POST, new HashMap<String,Method>());
_methods.put(PUT, new HashMap<String,Method>());
for (Method method : cl.getMethods()) {
if (method.getDeclaringClass().equals(Object.class))
continue;
int modifiers = method.getModifiers();
// Allow abstract for interfaces
if (Modifier.isStatic(modifiers)
|| Modifier.isFinal(modifiers)
|| ! Modifier.isPublic(modifiers))
continue;
String methodName = method.getName();
if (method.isAnnotationPresent(WebMethod.class)) {
WebMethod webMethod =
(WebMethod) method.getAnnotation(WebMethod.class);
if (! "".equals(webMethod.operationName()))
methodName = webMethod.operationName();
}
if (method.isAnnotationPresent(RestMethod.class)) {
RestMethod restMethod =
(RestMethod) method.getAnnotation(RestMethod.class);
if (! "".equals(restMethod.operationName()))
methodName = restMethod.operationName();
}
boolean hasHTTPMethod = false;
if (method.isAnnotationPresent(Delete.class)) {
if (_methods.get(DELETE).containsKey(methodName)) {
throw new UnsupportedOperationException("Overloaded method: " +
method.getName());
}
_methods.get(DELETE).put(methodName, method);
hasHTTPMethod = true;
}
if (method.isAnnotationPresent(Get.class)) {
if (_methods.get(GET).containsKey(methodName)) {
throw new UnsupportedOperationException("Overloaded method: " +
method.getName());
}
_methods.get(GET).put(methodName, method);
hasHTTPMethod = true;
}
if (method.isAnnotationPresent(Post.class)) {
if (_methods.get(POST).containsKey(methodName)) {
throw new UnsupportedOperationException("Overloaded method: " +
method.getName());
}
_methods.get(POST).put(methodName, method);
hasHTTPMethod = true;
}
if (method.isAnnotationPresent(Put.class)) {
if (_methods.get(PUT).containsKey(methodName)) {
throw new UnsupportedOperationException("Overloaded method: " +
method.getName());
}
_methods.get(PUT).put(methodName, method);
hasHTTPMethod = true;
}
if (method.isAnnotationPresent(Head.class)) {
if (_methods.get(HEAD).containsKey(methodName)) {
throw new UnsupportedOperationException("Overloaded method: " +
method.getName());
}
_methods.get(HEAD).put(methodName, method);
hasHTTPMethod = true;
}
if (! hasHTTPMethod) {
if (_defaultMethods.containsKey(methodName)) {
throw new UnsupportedOperationException("Overloaded method: " +
method.getName());
}
_defaultMethods.put(methodName, method);
}
}
}
catch (Exception e) {
throw new ServletException(e);
}
}
public void service(ServletRequest request, ServletResponse response)
throws ServletException, IOException
{
HttpServletRequest req = (HttpServletRequest) request;
HttpServletResponse res = (HttpServletResponse) response;
Map<String,String> queryArguments = new HashMap<String,String>();
if (req.getQueryString() != null)
queryToMap(req.getQueryString(), queryArguments);
String[] pathArguments = null;
if (req.getPathInfo() != null) {
String pathInfo = req.getPathInfo();
// remove the initial and final slashes
int startPos = 0;
int endPos = pathInfo.length();
if (pathInfo.length() > 0 && pathInfo.charAt(0) == '/')
startPos = 1;
if (pathInfo.length() > startPos &&
pathInfo.charAt(pathInfo.length() - 1) == '/')
endPos = pathInfo.length() - 1;
pathInfo = pathInfo.substring(startPos, endPos);
pathArguments = pathInfo.split("/");
if (pathArguments.length == 1 && pathArguments[0].length() == 0)
pathArguments = new String[0];
}
else
pathArguments = new String[0];
try {
invoke(_service, req.getMethod(), pathArguments, queryArguments,
req, req.getInputStream(), res.getOutputStream());
}
catch (NoSuchMethodException e) {
res.sendError(HttpServletResponse.SC_BAD_REQUEST);
}
catch (Throwable e) {
throw new ServletException(e);
}
}
private static void queryToMap(String query,
Map<String,String> queryArguments)
{
String[] entries = query.split("&");
for (String entry : entries) {
if (entry.indexOf("=") < 0)
continue;
String[] nameValue = entry.split("=", 2);
queryArguments.put(nameValue[0], nameValue[1]);
}
}
private void invoke(Object object,
String httpMethod,
String[] pathArguments,
Map<String,String> queryArguments,
HttpServletRequest req,
InputStream postData,
OutputStream out)
throws Throwable
{
int pathIndex = 0;
boolean pathMethod = false;
// Two special approaches: path and query
//
// Path takes the first part of the path as the method name
//
// Query checks for /?method=myMethod in the query part
//
// Query overrides path since it's more explicit
String methodName = queryArguments.get("method");
if ((methodName == null) && (pathArguments.length > 0)) {
methodName = pathArguments[0];
if (methodName != null)
pathMethod = true;
}
// First, look by http method and method name
// This may hit the default method since methodName can be null
Method method = _methods.get(httpMethod).get(methodName);
// next, check for a default method, ignoring http method
if (method == null)
method = _defaultMethods.get(methodName);
// finally, check for a completely default method
if (method == null) {
method = _defaultMethods.get(null);
pathMethod = false;
}
if (method == null)
throw new NoSuchMethodException(methodName);
if (pathMethod)
pathIndex = 1;
// Construct the arguments for the invocation
ArrayList arguments = new ArrayList();
Class[] parameterTypes = method.getParameterTypes();
Annotation[][] annotations = method.getParameterAnnotations();
for (int i = 0; i < parameterTypes.length; i++) {
RestParam.Source source = RestParam.Source.QUERY;
String key = "arg" + i;
for (int j = 0; j < annotations[i].length; j++) {
if (annotations[i][j].annotationType().equals(RestParam.class)) {
RestParam restParam = (RestParam) annotations[i][j];
source = restParam.source();
}
else if (annotations[i][j].annotationType().equals(WebParam.class)) {
WebParam webParam = (WebParam) annotations[i][j];
if (! "".equals(webParam.name()))
key = webParam.name();
}
}
switch (source) {
case PATH:
{
String arg = null;
if (pathIndex < pathArguments.length)
arg = pathArguments[pathIndex++];
arguments.add(stringToType(parameterTypes[i], arg));
// XXX var args
}
break;
case QUERY:
arguments.add(stringToType(parameterTypes[i],
queryArguments.get(key)));
break;
case POST:
arguments.add(readPostData(postData));
break;
case HEADER:
arguments.add(stringToType(parameterTypes[i], req.getHeader(key)));
break;
}
}
Object result = method.invoke(object, arguments.toArray());
if (result != null)
writeResponse(out, result);
}
protected abstract Object readPostData(InputStream in)
throws IOException, RestException;
protected abstract void writeResponse(OutputStream out, Object obj)
throws IOException, RestException;
private static Object stringToType(Class type, String arg)
throws Throwable
{
if (arg == null) {
return null;
}
else if (type.equals(boolean.class)) {
return new Boolean(arg);
}
else if (type.equals(Boolean.class)) {
return new Boolean(arg);
}
else if (type.equals(byte.class)) {
return new Byte(arg);
}
else if (type.equals(Byte.class)) {
return new Byte(arg);
}
else if (type.equals(char.class)) {
if (arg.length() != 1) {
throw new IllegalArgumentException("Cannot convert String to type " +
type.getName());
}
return new Character(arg.charAt(0));
}
else if (type.equals(Character.class)) {
if (arg.length() != 1) {
throw new IllegalArgumentException("Cannot convert String to type " +
type.getName());
}
return new Character(arg.charAt(0));
}
else if (type.equals(double.class)) {
return new Double(arg);
}
else if (type.equals(Double.class)) {
return new Double(arg);
}
else if (type.equals(float.class)) {
return new Float(arg);
}
else if (type.equals(Float.class)) {
return new Float(arg);
}
else if (type.equals(int.class)) {
return new Integer(arg);
}
else if (type.equals(Integer.class)) {
return new Integer(arg);
}
else if (type.equals(long.class)) {
return new Long(arg);
}
else if (type.equals(Long.class)) {
return new Long(arg);
}
else if (type.equals(short.class)) {
return new Short(arg);
}
else if (type.equals(Short.class)) {
return new Short(arg);
}
else if (type.equals(String.class)) {
return arg;
}
else
throw new IllegalArgumentException("Cannot convert String to type " +
type.getName());
}
}