package org.gap.jseed;
import static org.gap.jseed.JavaWriter.braces;
import static org.gap.jseed.JavaWriter.call;
import static org.gap.jseed.JavaWriter.createParameters;
import static org.gap.jseed.JavaWriter.getInvokingMethodCode;
import static org.gap.jseed.JavaWriter.getParameterTypes;
import static org.gap.jseed.JavaWriter.line;
import static org.gap.jseed.JavaWriter.returnCall;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationHandler;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.Map;

import javassist.CannotCompileException;
import javassist.ClassClassPath;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.NotFoundException;

import org.gap.jseed.util.ClassUtil;
/**
 * Injects behavior into a Javassist-generated implementation class.
 * <p>
 * Annotations are mapped to {@link InvocationHandler} types with {@link #add(Class, Class)}.
 * When a class or interface in the hierarchy of the type being implemented carries one of
 * those annotations, an {@code _invocation} field is injected for the mapped handler class
 * and the hierarchy's methods are rewritten to call {@code invoke} on that field (the call
 * text is produced by the {@code JavaWriter} helpers). When no registered annotation is
 * found, each method instead receives the default behavior from {@code injectDefaultMethod}
 * (not defined in this class; presumably inherited from {@code AbstractInjector}).
 * <p>
 * Validators registered with {@link #validateWith(Class, Validator)} are run against the
 * interface whenever its annotation is matched during the hierarchy scan.
 */
public class ClassInjector extends AbstractInjector {
    private static final String INVOKE_METHOD = "invoke";
    private static final String HASH_CODE = "hashCode";
    private static final String TO_STRING = "toString";
    private static final String EQUALS = "equals";
    private static final String HASH_CODE_SIGNATURE = "public int hashCode();";
    private static final String EQUALS_SIGNATURE = "public boolean equals(Object obj);";
    private static final String TO_STRING_SIGNATURE = "public String toString();";
    /** Name of the field injected into the implementation to hold the invocation handler. */
    private static final String INVOCATION_FIELD = "_invocation";

    /** Registered annotation type -> handler class injected when that annotation is found. */
    private final Map<Class<? extends Annotation>, Class<? extends InvocationHandler>> annotations;
    /** Optional validator per annotation type, run against the interface on a match. */
    private final Map<Class<? extends Annotation>, Validator> validators;

    /** Creates an injector with no registered annotations and no validators. */
    public ClassInjector() {
        annotations = new HashMap<Class<? extends Annotation>, Class<? extends InvocationHandler>>();
        validators = new HashMap<Class<? extends Annotation>, Validator>();
    }

    /**
     * Injects behavior into {@code implementation} for every class and interface in the
     * hierarchy of {@code type}.
     * <p>
     * If any type in that hierarchy carries a registered annotation, the invocation-handler
     * field is injected and the methods delegate to the handler. Otherwise the default method
     * behavior is injected.
     *
     * @param theInterface   the interface being implemented; passed to a matching validator
     * @param type           the type whose hierarchy is scanned for annotations and methods
     * @param implementation the class being modified
     * @throws ClassNotFoundException if an annotation on the hierarchy cannot be loaded
     * @throws CannotCompileException if generated code cannot be compiled by Javassist
     * @throws NotFoundException      if a referenced class or method cannot be resolved
     */
    public void injectBehavior(Class<?> theInterface, CtClass type, CtClass implementation)
            throws ClassNotFoundException, CannotCompileException, NotFoundException {
        Collection<CtClass> allClasses = getClassHierarchy(type);
        if (isClassAnnotated(theInterface, implementation, allClasses)) {
            injectClassBehavior(implementation, allClasses);
        } else {
            injectDefaultClassBehavior(implementation, allClasses);
        }
    }

    /**
     * Returns whether any type in {@code allClasses} carries a registered annotation.
     * <p>
     * This is not a pure query. On the first match it runs that annotation's validator (if
     * any) and injects the invocation-handler field into {@code implementation}. Each
     * unregistered annotation seen before a match is reported on {@code System.err}.
     */
    private boolean isClassAnnotated(Class<?> theInterface, CtClass implementation,
            Collection<CtClass> allClasses)
            throws ClassNotFoundException, CannotCompileException, NotFoundException {
        for (CtClass ctClass : allClasses) {
            for (Object each : ctClass.getAnnotations()) {
                if (isAnnotationRegistered(theInterface, implementation, each)) {
                    return true;
                }
                System.err.println(each + ": not found, on class " + theInterface);
            }
        }
        return false;
    }

    /**
     * Checks whether {@code each} is an instance of a registered annotation type.
     * <p>
     * On a match it validates {@code theInterface} and injects the handler field, then
     * returns true. Only the first matching registration is applied.
     */
    private boolean isAnnotationRegistered(Class<?> theInterface, CtClass implementation, Object each)
            throws CannotCompileException, NotFoundException {
        for (Map.Entry<Class<? extends Annotation>, Class<? extends InvocationHandler>> entry
                : annotations.entrySet()) {
            if (entry.getKey().isInstance(each)) {
                performValidationOnClassWith(theInterface, entry.getKey());
                injectField(INVOCATION_FIELD, implementation, entry.getValue());
                return true;
            }
        }
        return false;
    }

    /** Runs the validator registered for {@code eachAnnotation}, if there is one. */
    private void performValidationOnClassWith(Class<?> theInterface, Class<? extends Annotation> eachAnnotation) {
        Validator validator = validators.get(eachAnnotation);
        if (validator != null) {
            validator.validate(theInterface);
        }
    }

    /**
     * Injects invocation-handler behavior for an explicitly chosen annotation.
     * <p>
     * This does not scan the hierarchy for annotations and does not run the annotation's
     * validator.
     * <p>
     * NOTE(review): {@code theInterface} is currently unused. If {@code annotation} was never
     * registered via {@link #add(Class, Class)}, a {@code null} handler class is passed to
     * {@code injectField}; confirm how that is meant to be handled.
     */
    public void injectBehavior(Class<? extends Annotation> annotation, Class<?> theInterface, CtClass type, CtClass implementation)
            throws CannotCompileException, NotFoundException {
        injectField(INVOCATION_FIELD, implementation, annotations.get(annotation));
        injectClassBehavior(implementation, getClassHierarchy(type));
    }

    /**
     * Rewrites every method of the hierarchy that {@code implementation} resolves to a
     * non-abstract method so that it calls the invocation handler. Afterwards it fills in
     * {@code equals}, {@code toString} and {@code hashCode} if they are not already defined.
     */
    private void injectClassBehavior(CtClass implementation, Collection<CtClass> allClasses)
            throws NotFoundException, CannotCompileException {
        for (CtClass eachInterface : allClasses) {
            for (CtMethod each : eachInterface.getDeclaredMethods()) {
                if (isMethodDefinedFor(implementation, each)) {
                    insertCallInvocationHandler(implementation, each);
                }
            }
        }
        // Done once, after the whole hierarchy. Running this inside the loop repeated the work
        // for every level, and could inject an Object method before a later type in the
        // hierarchy declared that same method, causing a second injection.
        injectObjectMethod(implementation, EQUALS, EQUALS_SIGNATURE);
        injectObjectMethod(implementation, TO_STRING, TO_STRING_SIGNATURE);
        injectObjectMethod(implementation, HASH_CODE, HASH_CODE_SIGNATURE);
    }

    /** Applies {@code injectDefaultMethod} to every hierarchy method resolved as non-abstract. */
    private void injectDefaultClassBehavior(CtClass implementation,
            Collection<CtClass> allClasses) throws CannotCompileException, NotFoundException {
        for (CtClass eachInterface : allClasses) {
            for (CtMethod each : eachInterface.getDeclaredMethods()) {
                if (isMethodDefinedFor(implementation, each)) {
                    injectDefaultMethod(implementation, each);
                }
            }
        }
    }

    /**
     * Returns true when {@code implementation} resolves a method with the same name and
     * signature as {@code each} and that method is not abstract.
     */
    private boolean isMethodDefinedFor(CtClass implementation, CtMethod each) {
        try {
            CtMethod method = implementation.getMethod(each.getName(), each.getSignature());
            return !ClassUtil.isMethodAbstract(method);
        } catch (Exception ignored) {
            // Deliberate best effort: a method that cannot be resolved (e.g. NotFoundException
            // from getMethod) or inspected is treated as "not defined".
            return false;
        }
    }

    /**
     * Collects the abstract superclasses of {@code type} (see {@link #getExtendedClasses}),
     * then {@code type} itself, each followed by its directly implemented interfaces.
     * <p>
     * The result is insertion-ordered and duplicate-free. An interface implemented at more
     * than one level is therefore listed once, so its methods are not injected twice. This
     * relies on the ClassPool handing out the same CtClass instance for a given class name.
     */
    private Collection<CtClass> getClassHierarchy(CtClass type) {
        Collection<CtClass> result = new LinkedHashSet<CtClass>();
        Collection<CtClass> extendedClasses = getExtendedClasses(type);
        extendedClasses.add(type);
        for (CtClass eachClass : extendedClasses) {
            result.add(eachClass);
            Collections.addAll(result, getInterfacesFor(eachClass));
        }
        return result;
    }

    /** Returns the interfaces of {@code implementingClass}, or an empty array when it has none. */
    private CtClass[] getInterfacesFor(CtClass implementingClass) {
        if (ClassUtil.hasInterfaces(implementingClass)) {
            return ClassUtil.getInterfaces(implementingClass);
        }
        return new CtClass[]{};
    }

    /**
     * Walks up from {@code current} while {@code ClassUtil.hasAbstractClass} reports an
     * abstract superclass. Returns those superclasses nearest-first, in a mutable collection.
     */
    private Collection<CtClass> getExtendedClasses(CtClass current) {
        LinkedList<CtClass> result = new LinkedList<CtClass>();
        while (ClassUtil.hasAbstractClass(current)) {
            current = ClassUtil.getSuperClass(current);
            result.add(current);
        }
        return result;
    }

    /**
     * Adds a handler-delegating implementation of the given Object method to
     * {@code implementation}, unless a method with that name is already defined on it.
     */
    private void injectObjectMethod(CtClass implementation, String methodName, String signature)
            throws CannotCompileException, NotFoundException {
        if (!ClassUtil.isMethodDefinedOn(methodName, implementation)) {
            CtMethod method = CtMethod.make(signature, implementation);
            insertCallInvocationHandler(implementation, method);
        }
    }

    /**
     * Builds a method body for {@code each} that calls the invocation handler, and injects it
     * into {@code implementation}.
     * <p>
     * The body consists of the "invoking method" lookup code followed by the handler call,
     * wrapped in braces.
     */
    private void insertCallInvocationHandler(CtClass implementation, CtMethod each)
            throws NotFoundException, CannotCompileException {
        String parameterTypes = getParameterTypes(each.getParameterTypes());
        StringBuilder methodBody = new StringBuilder();
        methodBody.append(getInvokingMethodCode(each, parameterTypes));
        methodBody.append(writeInvocationHandlerCallMethodBody(each));
        injectMethod(implementation, each, braces(methodBody.toString()));
    }

    /**
     * Generates the statement that calls {@code _invocation.invoke} with {@code this},
     * {@code method} and one argument per parameter of {@code each}. The statement is wrapped
     * by {@code returnCall} according to the method's return type.
     */
    private String writeInvocationHandlerCallMethodBody(CtMethod each)
            throws NotFoundException {
        return returnCall(each,
                line(
                        call(INVOKE_METHOD,
                                INVOCATION_FIELD,
                                "this",
                                "method",
                                createParameters(each.getParameterTypes().length)
                        )
                )
        );
    }

    /**
     * Registers {@code handler} as the invocation handler for types carrying
     * {@code annotation}.
     * <p>
     * It also adds the handler's class to the default ClassPool's search path so that
     * Javassist can resolve it.
     */
    public void add(Class<? extends Annotation> annotation,
            Class<? extends InvocationHandler> handler) {
        annotations.put(annotation, handler);
        ClassPool.getDefault().insertClassPath(new ClassClassPath(handler));
    }

    /** Registers {@code validator} to be run on the interface when {@code annotation} matches. */
    public void validateWith(Class<? extends Annotation> annotation,
            Validator validator) {
        validators.put(annotation, validator);
    }
}