package yalp.classloading.enhancers;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.util.Stack;
import javassist.CannotCompileException;
import javassist.CtClass;
import javassist.CtField;
import javassist.CtMethod;
import javassist.Modifier;
import javassist.NotFoundException;
import javassist.bytecode.annotation.Annotation;
import javassist.expr.ExprEditor;
import javassist.expr.FieldAccess;
import javassist.expr.Handler;
import yalp.Logger;
import yalp.classloading.ApplicationClasses.ApplicationClass;
import yalp.exceptions.UnexpectedException;
/**
 * Bytecode enhancer for controller classes.
 *
 * <p>Only classes that are subtypes of {@link ControllerSupport} are touched. For every
 * method declared by such a class the enhancer:
 * <ol>
 * <li>replaces reads of the "threaded" controller fields (see
 * {@link #isThreadedFieldAccess(CtField)}) with a call to the static {@code current}
 * accessor of the field's type;</li>
 * <li>injects auto-reverse code (Scala objects) or auto-redirect code (public static void
 * methods) at the start of the method, unless the method is exempt;</li>
 * <li>makes every catch block rethrow {@code yalp.mvc.results.Result} and
 * {@code yalp.Invoker.Suspend} so that user code cannot swallow them.</li>
 * </ol>
 */
public class ControllersEnhancer extends Enhancer {

    /**
     * Per-thread stack of action names. The code injected into Scala controller methods
     * pushes "Class.method" on entry and pops it on exit. This class never assigns a stack
     * to the thread-local itself, so a stack must already be set for the current thread
     * when such a method runs.
     */
    public static ThreadLocal<Stack<String>> currentAction = new ThreadLocal<Stack<String>>();

    /**
     * Enhances the given application class if it is a controller.
     *
     * <p>Anonymous classes and classes that are not subtypes of {@link ControllerSupport}
     * are left untouched. On success the rewritten bytecode is stored in
     * {@code applicationClass.enhancedByteCode}.
     *
     * @param applicationClass the class to enhance
     * @throws Exception if the class cannot be loaded or instrumented
     */
    @Override
    public void enhanceThisClass(final ApplicationClass applicationClass) throws Exception {
        if (isAnon(applicationClass)) {
            return;
        }
        CtClass ctClass = makeClass(applicationClass);
        if (!ctClass.subtypeOf(classPool.get(ControllerSupport.class.getName()))) {
            return;
        }

        for (final CtMethod ctMethod : ctClass.getDeclaredMethods()) {
            // Threaded access
            rewriteThreadedFieldReads(applicationClass, ctMethod);

            // Auto-redirect / auto-reverse
            boolean exempt = isExemptFromInterception(ctClass, ctMethod);
            if (isScalaObject(ctClass)) {
                if (Modifier.isPublic(ctMethod.getModifiers())
                        && ctClass.getName().endsWith("$")
                        && !ctMethod.getName().contains("$default$")
                        && !exempt) {
                    injectAutoReverse(applicationClass, ctClass, ctMethod);
                }
            } else {
                // The return type is resolved only for public static methods, and before the
                // exemption flag is consulted, exactly as the conditions are ordered here.
                if (Modifier.isPublic(ctMethod.getModifiers())
                        && Modifier.isStatic(ctMethod.getModifiers())
                        && ctMethod.getReturnType().equals(CtClass.voidType)
                        && !exempt) {
                    injectAutoRedirect(applicationClass, ctClass, ctMethod);
                }
            }

            // Enhance global catch to avoid potential unwanted catching of yalp.mvc.results.Result
            guardCatchBlocks(ctMethod);
        }

        // Done.
        applicationClass.enhancedByteCode = ctClass.toBytecode();
        ctClass.defrost();
    }

    /**
     * Replaces every read of a threaded controller field inside {@code ctMethod} with
     * {@code yalp.utils.Java.invokeStatic($type, "current")}. Writes are left untouched.
     */
    private void rewriteThreadedFieldReads(final ApplicationClass applicationClass, final CtMethod ctMethod)
            throws CannotCompileException {
        ctMethod.instrument(new ExprEditor() {
            @Override
            public void edit(FieldAccess fieldAccess) throws CannotCompileException {
                try {
                    if (isThreadedFieldAccess(fieldAccess.getField()) && fieldAccess.isReader()) {
                        fieldAccess.replace("$_ = ($r)yalp.utils.Java.invokeStatic($type, \"current\");");
                    }
                } catch (Exception e) {
                    Logger.error(e, "Error in ControllersEnhancer. %s.%s has not been properly enhanced (fieldAccess %s).", applicationClass.name, ctMethod.getName(), fieldAccess);
                    throw new UnexpectedException(e);
                }
            }
        });
    }

    /**
     * Tells whether {@code ctMethod} must be skipped by auto-redirect / auto-reverse.
     *
     * <p>A method is exempt when it carries an annotation whose type name starts with
     * {@code yalp.mvc.} or ends with {@code $ByPass}, when its name contains {@code $}
     * (perhaps a Scala-generated method), or when it is a zero-argument method of a class
     * whose name ends with {@code $} and which has a field of the same name.
     */
    private boolean isExemptFromInterception(CtClass ctClass, CtMethod ctMethod) throws Exception {
        boolean exempt = false;
        for (Annotation annotation : getAnnotations(ctMethod).getAnnotations()) {
            String typeName = annotation.getTypeName();
            if (typeName.startsWith("yalp.mvc.") || typeName.endsWith("$ByPass")) {
                exempt = true;
                break;
            }
        }

        // Perhaps it is a scala-generated method ?
        if (ctMethod.getName().contains("$")) {
            exempt = true;
        } else if (ctClass.getName().endsWith("$") && ctMethod.getParameterTypes().length == 0) {
            try {
                ctClass.getField(ctMethod.getName());
                exempt = true;
            } catch (NotFoundException ignored) {
                // No field with that name: the method is not a field accessor.
            }
        }
        return exempt;
    }

    /**
     * Injects the Scala auto-reverse prologue and the action-stack bookkeeping into
     * {@code ctMethod}.
     */
    private void injectAutoReverse(ApplicationClass applicationClass, CtClass ctClass, CtMethod ctMethod) {
        String actionName = ctClass.getName().replace("$", "") + "." + ctMethod.getName();
        try {
            // Statement order matters: insertBefore prepends, so the push inserted second
            // runs ahead of the reverse check inserted first.
            ctMethod.insertBefore(
                    "if(yalp.mvc.Controller._currentReverse.get() != null) {"
                    + "yalp.mvc.Controller.redirect(\"" + actionName + "\", $args);"
                    + generateValidReturnStatement(ctMethod.getReturnType())
                    + "}");
            ctMethod.insertBefore(
                    "((java.util.Stack)yalp.classloading.enhancers.ControllersEnhancer.currentAction.get()).push(\"" + actionName + "\");");
            // asFinally = true: the pop is also executed when the method exits by exception.
            ctMethod.insertAfter(
                    "((java.util.Stack)yalp.classloading.enhancers.ControllersEnhancer.currentAction.get()).pop();", true);
        } catch (Exception e) {
            Logger.error(e, "Error in ControllersEnhancer. %s.%s has not been properly enhanced (auto-reverse).", applicationClass.name, ctMethod.getName());
            throw new UnexpectedException(e);
        }
    }

    /**
     * Injects the auto-redirect prologue into {@code ctMethod}: when the action call is not
     * flagged as allowed for the current thread, the injected code calls
     * {@code yalp.mvc.Controller.redirect} with the action name and returns; otherwise it
     * clears the flag and lets the method body run.
     */
    private void injectAutoRedirect(ApplicationClass applicationClass, CtClass ctClass, CtMethod ctMethod) {
        String actionName = ctClass.getName().replace("$", "") + "." + ctMethod.getName();
        try {
            ctMethod.insertBefore(
                    "if(!yalp.classloading.enhancers.ControllersEnhancer.ControllerInstrumentation.isActionCallAllowed()) {"
                    + "yalp.mvc.Controller.redirect(\"" + actionName + "\", $args);"
                    + generateValidReturnStatement(ctMethod.getReturnType()) + "}"
                    + "yalp.classloading.enhancers.ControllersEnhancer.ControllerInstrumentation.stopActionCall();");
        } catch (Exception e) {
            Logger.error(e, "Error in ControllersEnhancer. %s.%s has not been properly enhanced (auto-redirect).", applicationClass.name, ctMethod.getName());
            throw new UnexpectedException(e);
        }
    }

    /**
     * Makes every catch block of {@code ctMethod} rethrow {@code yalp.mvc.results.Result}
     * and {@code yalp.Invoker.Suspend} before running its own code.
     */
    private void guardCatchBlocks(CtMethod ctMethod) throws CannotCompileException {
        ctMethod.instrument(new ExprEditor() {
            @Override
            public void edit(Handler handler) throws CannotCompileException {
                try {
                    handler.insertBefore("if($1 instanceof yalp.mvc.results.Result || $1 instanceof yalp.Invoker.Suspend) throw $1;");
                } catch (NullPointerException ignored) {
                    // TODO: finally clause ?
                    // There are no $1 in finally statements in javassist
                }
            }
        });
    }

    /**
     * Marks classes that need controller enhancement.
     */
    public static interface ControllerSupport {
    }

    /**
     * Checks whether a field must be translated to a 'thread safe field'.
     *
     * @param field the accessed field
     * @return {@code true} if the field is declared by {@code yalp.mvc.Controller} or
     *         {@code yalp.mvc.WebSocketController} and is one of the threaded fields
     */
    static boolean isThreadedFieldAccess(CtField field) {
        String owner = field.getDeclaringClass().getName();
        if (!owner.equals("yalp.mvc.Controller") && !owner.equals("yalp.mvc.WebSocketController")) {
            return false;
        }
        String name = field.getName();
        return name.equals("params")
                || name.equals("request")
                || name.equals("response")
                || name.equals("session")
                || name.equals("renderArgs")
                || name.equals("routeArgs")
                || name.equals("validation")
                || name.equals("inbound")
                || name.equals("outbound")
                || name.equals("flash");
    }

    /**
     * Runtime part needed by the instrumentation: a per-thread flag read by the injected
     * auto-redirect code.
     */
    public static class ControllerInstrumentation {

        /**
         * @return {@code true} only if the flag was set by {@link #initActionCall()} on the
         *         current thread and not cleared since; a thread that never set the flag
         *         gets {@code false} instead of a NullPointerException from unboxing
         */
        public static boolean isActionCallAllowed() {
            return Boolean.TRUE.equals(allow.get());
        }

        /** Sets the flag for the current thread. */
        public static void initActionCall() {
            allow.set(true);
        }

        /** Clears the flag for the current thread. */
        public static void stopActionCall() {
            allow.set(false);
        }

        // No initial value: get() returns null until one of the setters above is called.
        static ThreadLocal<Boolean> allow = new ThreadLocal<Boolean>();
    }

    /**
     * Method annotation; a method carrying it is skipped by the auto-redirect /
     * auto-reverse instrumentation.
     */
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface ByPass {
    }

    /**
     * Builds a return statement that is valid for a method of the given return type, for
     * use in injected source code.
     *
     * @param type the return type of the method being instrumented
     * @return {@code "return;"} for void, a zero/false return for primitives, and
     *         {@code "return null;"} for everything else
     */
    static String generateValidReturnStatement(CtClass type) {
        if (type.equals(CtClass.voidType)) {
            return "return;";
        }
        if (type.equals(CtClass.booleanType)) {
            return "return false;";
        }
        if (type.equals(CtClass.charType)) {
            // Explicit cast like the other primitive branches; an empty char literal ('')
            // is not valid Java source.
            return "return (char)0;";
        }
        if (type.equals(CtClass.byteType)) {
            return "return (byte)0;";
        }
        if (type.equals(CtClass.doubleType)) {
            return "return (double)0;";
        }
        if (type.equals(CtClass.floatType)) {
            return "return (float)0;";
        }
        if (type.equals(CtClass.intType)) {
            return "return (int)0;";
        }
        if (type.equals(CtClass.longType)) {
            return "return (long)0;";
        }
        if (type.equals(CtClass.shortType)) {
            return "return (short)0;";
        }
        return "return null;";
    }
}