package de.jungblut.classification.tree;
import java.util.Random;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import de.jungblut.math.DoubleVector;
/**
 * Compilation unit for the object tree structure of the {@link DecisionTree}.
 * It uses the ASM framework to generate bytecode from the object structure
 * under the proxy of an {@link AbstractTreeNode}. <br/>
 * Nodes are compiled in the following manner:
 * <ul>
 * <li>Nominal nodes are directly compiled into switch statements, default case
 * is a return of -1.</li>
 * <li>Leaf nodes directly return their label.</li>
 * <li>Numerical nodes are compiled into a single branch that compares the value
 * of the split with a reference.</li>
 * </ul>
 * The data that needs to be stored in order to compare is put into the constant
 * space of a class. Thus this is limited to 2^16 entries on some JVMs. The
 * created class is called "CompiledNode_{timestamp}_{random}" and its
 * {@link AbstractTreeNode#transformToByteCode(MethodVisitor, Label)} method
 * throws an {@link UnsupportedOperationException}.
 *
 * @author thomas.jungblut
 */
public final class TreeCompiler implements Opcodes {
// Prefix of the class names produced by generateClassName().
private static final String CLAZZ_NAME = "CompiledNode";
// Supplies the random suffix that generateClassName() appends to the name.
private static final Random RNG = new Random();
* Compiles the given node and directly loads it.
* @param name the name of the class. Possibly generated by
* {@link #generateClassName()}.
* @return a new compiled {@link AbstractTreeNode}.
public static AbstractTreeNode compileAndLoad(String name,
AbstractTreeNode node) throws Exception {
return load(name, compileNode(name, node));
* Loads the given tree node via its name and bytecode.
* @param name the name of the class.
* @param byteCode the byte code of the class.
* @return a new {@link AbstractTreeNode}.
public static AbstractTreeNode load(String name, byte[] byteCode)
throws Exception {
Class<AbstractTreeNode> loadClass = loadClass(byteCode, name);
return loadClass.newInstance();
/**
 * Compiles the given tree node into a class with the given name.
 *
 * @param name the name of the class.
 * @param root the tree node to compile.
 * @return a byte[] representing the class contents.
 */
public static byte[] compileNode(String name, AbstractTreeNode root)
throws Exception {
// Emits a public final class called `name` that extends AbstractTreeNode and
// returns its class file bytes.
// NOTE(review): this body looks damaged. `mv` is declared three times in what
// appears to be one scope, and several statements begin or end
// mid-expression (flagged below). Confirm against the upstream file.
// COMPUTE_MAXS asks ASM to work out max stack / max locals on its own.
ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS);
// Class header: version 1.6, public final, superclass AbstractTreeNode, no
// signature and no interfaces.
cw.visit(Opcodes.V1_6, ACC_PUBLIC + ACC_FINAL, name, null,
Type.getInternalName(AbstractTreeNode.class), null);
cw.visitSource(name + ".java", null);
// add the public no-argument constructor "<init>()V"
MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "<init>", "()"
+ Type.VOID_TYPE.getDescriptor(), null, null);
// push `this` (local variable 0)
mv.visitVarInsn(ALOAD, 0);
// NOTE(review): the next two lines are an argument list without a call head;
// presumably the tail of an INVOKESPECIAL of the AbstractTreeNode
// constructor - confirm.
Type.getInternalName(AbstractTreeNode.class), "<init>", "()"
+ Type.VOID_TYPE.getDescriptor());
mv.visitMaxs(1, 1);
// override the "predict" method
// NOTE(review): the descriptor expression breaks off after `Type` and resumes
// mid-argument-list on the following line - confirm what was lost.
MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "predict", Type
"predict", DoubleVector.class)), null, null);
// The root node writes its own bytecode into `mv`; `end` is handed to it as
// the label it may reference.
Label end = new Label();
root.transformToByteCode(mv, end);
// return the last int on the stack.
// NOTE(review): no instruction that places `end` or returns is visible here.
mv.visitMaxs(0, 0);
// make the "transformToByteCode" throw an exception
// NOTE(review): same truncation pattern as for "predict" above; the
// visitMethod call is also not closed before the next statement starts.
MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "transformToByteCode", Type
"transformToByteCode", MethodVisitor.class, Label.class)), null,
// allocate the UnsupportedOperationException
mv.visitTypeInsn(NEW, "java/lang/UnsupportedOperationException");
// NOTE(review): argument list without a call head; presumably the tail of
// the constructor invocation of the exception - confirm.
"java/lang/UnsupportedOperationException", "<init>", "()V");
mv.visitMaxs(0, 0);
// serialize everything visited so far into the class file bytes
return cw.toByteArray();
* @return generate a pseudo-unique classname using the classname prefix and
* the timestamp in ms, because some collisions happened, a random
* string is appended as well.
public static synchronized String generateClassName() {
return CLAZZ_NAME + "_" + System.currentTimeMillis() + "_"
+ Integer.toString(Math.abs(RNG.nextInt()), 36);
private static <CLAZZ> Class<CLAZZ> loadClass(byte[] b, String className)
throws Exception {
// look up ClassLoader#defineClass reflectively (as it is protected) to define the class.
ClassLoader loader = ClassLoader.getSystemClassLoader();
Class<CLAZZ> cls = (Class<CLAZZ>) Class.forName("java.lang.ClassLoader");
java.lang.reflect.Method method = cls.getDeclaredMethod("defineClass",
new Class<?>[] { String.class, byte[].class, int.class, int.class });
try {
Class<CLAZZ> clz = (Class<CLAZZ>) Class.forName(className);
return clz;
} catch (Exception e) {
// swallow.
// protected method invocation
try {
Object[] args = new Object[] { className, b, new Integer(0),
new Integer(b.length) };
return (Class<CLAZZ>) method.invoke(loader, args);
} finally {