Package de.jungblut.classification.tree

Source Code of de.jungblut.classification.tree.TreeCompiler

package de.jungblut.classification.tree;

import java.util.Random;

import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import de.jungblut.math.DoubleVector;

/**
* Compilation unit for the object tree structure of the {@link DecisionTree}.
* It uses the ASM framework to calculate bytecode from the object structure
* under the proxy of a {@link AbstractTreeNode}. <br/>
* Nodes are compiled in the following manner:
* <ul>
* <li>Nominal nodes are directly compiled into switch statements, default case
* is a return of -1.</li>
* <li>Leaf nodes directly return their label.</li>
* <li>Numerical nodes are compiled into a single branch that compares the value
* of the split with a reference.</li>
* </ul>
*
* The data that needs to be stored in order to compare is put into the constant
* space of a class. Thus this is only limited to 2^16 entries on some JVMs. The
* created class is called "CompiledNode_{timestamp}_{random}" and it's
* {@link AbstractTreeNode} {@link #compileAndLoad(AbstractTreeNode)} method
* throws an unsupported operation exception.
*
* @author thomas.jungblut
*
*/
public final class TreeCompiler implements Opcodes {

  private static final String CLAZZ_NAME = "CompiledNode";
  private static final Random RNG = new Random();

  /**
   * Compiles the given node and directly loads it.
   *
   * @param name the name of the class. Possibly generated by
   *          {@link #generateClassName()}.
   * @return a new compiled {@link AbstractTreeNode}.
   */
  public static AbstractTreeNode compileAndLoad(String name,
      AbstractTreeNode node) throws Exception {
    return load(name, compileNode(name, node));
  }

  /**
   * Loads the given tree node via its name and bytecode.
   *
   * @param name the name of the class.
   * @param byteCode the byte code of the class.
   * @return a new {@link AbstractTreeNode}.
   */
  public static AbstractTreeNode load(String name, byte[] byteCode)
      throws Exception {
    Class<AbstractTreeNode> loadClass = loadClass(byteCode, name);
    return loadClass.newInstance();
  }

  /**
   * Compiles the given tree node and name into a class.
   *
   * @param name the name of the class.
   * @param root the treenode to compile.
   * @return a byte[] representing the class contents.
   */
  public static byte[] compileNode(String name, AbstractTreeNode root)
      throws Exception {
    ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS);
    cw.visit(Opcodes.V1_6, ACC_PUBLIC + ACC_FINAL, name, null,
        Type.getInternalName(AbstractTreeNode.class), null);
    cw.visitSource(name + ".java", null);

    // add the constructor
    {
      MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "<init>", "()"
          + Type.VOID_TYPE.getDescriptor(), null, null);
      mv.visitVarInsn(ALOAD, 0);
      mv.visitMethodInsn(INVOKESPECIAL,
          Type.getInternalName(AbstractTreeNode.class), "<init>", "()"
              + Type.VOID_TYPE.getDescriptor());
      mv.visitInsn(RETURN);
      mv.visitMaxs(1, 1);
      mv.visitEnd();
    }
    // override the "predict" method
    {
      MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "predict", Type
          .getMethodDescriptor(AbstractTreeNode.class.getDeclaredMethod(
              "predict", DoubleVector.class)), null, null);
      Label end = new Label();

      root.transformToByteCode(mv, end);

      mv.visitLabel(end);
      // return the last int on the stack.
      mv.visitInsn(Type.INT_TYPE.getOpcode(IRETURN));
      mv.visitMaxs(0, 0);
      mv.visitEnd();
    }
    // make the "transformToByteCode" throw an exception
    {
      MethodVisitor mv = cw.visitMethod(ACC_PUBLIC, "transformToByteCode", Type
          .getMethodDescriptor(AbstractTreeNode.class.getDeclaredMethod(
              "transformToByteCode", MethodVisitor.class, Label.class)), null,
          null);
      mv.visitTypeInsn(NEW, "java/lang/UnsupportedOperationException");
      mv.visitInsn(DUP);
      mv.visitMethodInsn(INVOKESPECIAL,
          "java/lang/UnsupportedOperationException", "<init>", "()V");
      mv.visitInsn(ATHROW);
      mv.visitMaxs(0, 0);
      mv.visitEnd();
    }
    cw.visitEnd();

    return cw.toByteArray();
  }

  /**
   * @return generate a pseudo-unique classname using the classname prefix and
   *         the timestamp in ms, because some collisions happened, a random
   *         string is appended as well.
   */
  public static synchronized String generateClassName() {
    return CLAZZ_NAME + "_" + System.currentTimeMillis() + "_"
        + Integer.toString(Math.abs(RNG.nextInt()), 36);
  }

  @SuppressWarnings("unchecked")
  private static <CLAZZ> Class<CLAZZ> loadClass(byte[] b, String className)
      throws Exception {
    // override classDefine (as it is protected) and define the class.
    ClassLoader loader = ClassLoader.getSystemClassLoader();
    Class<CLAZZ> cls = (Class<CLAZZ>) Class.forName("java.lang.ClassLoader");
    java.lang.reflect.Method method = cls.getDeclaredMethod("defineClass",
        new Class<?>[] { String.class, byte[].class, int.class, int.class });

    try {
      Class<CLAZZ> clz = (Class<CLAZZ>) Class.forName(className);
      return clz;
    } catch (Exception e) {
      // swallow.
    }

    // protected method invocaton
    method.setAccessible(true);
    try {
      Object[] args = new Object[] { className, b, new Integer(0),
          new Integer(b.length) };
      return (Class<CLAZZ>) method.invoke(loader, args);
    } finally {
      method.setAccessible(false);
    }
  }
}
TOP

Related Classes of de.jungblut.classification.tree.TreeCompiler

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.