Package org.drools.base

Source Code of org.drools.base.ClassFieldExtractorFactory$ByteArrayClassLoader

package org.drools.base;

/*
* Copyright 2005 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import java.lang.reflect.Method;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.ProtectionDomain;
import java.util.HashMap;
import java.util.Map;

import org.drools.RuntimeDroolsException;
import org.drools.asm.ClassWriter;
import org.drools.asm.Label;
import org.drools.asm.MethodVisitor;
import org.drools.asm.Opcodes;
import org.drools.asm.Type;
import org.drools.base.extractors.BaseBooleanClassFieldExtractor;
import org.drools.base.extractors.BaseByteClassFieldExtractor;
import org.drools.base.extractors.BaseCharClassFieldExtractor;
import org.drools.base.extractors.BaseDoubleClassFieldExtractor;
import org.drools.base.extractors.BaseFloatClassFieldExtractor;
import org.drools.base.extractors.BaseIntClassFieldExtractor;
import org.drools.base.extractors.BaseLongClassFieldExtractors;
import org.drools.base.extractors.BaseObjectClassFieldExtractor;
import org.drools.base.extractors.BaseShortClassFieldExtractor;
import org.drools.base.extractors.MVELClassFieldExtractor;
import org.drools.base.extractors.SelfReferenceClassFieldExtractor;
import org.drools.common.InternalWorkingMemory;
import org.drools.util.asm.ClassFieldInspector;

/**
* This generates subclasses of BaseClassFieldExtractor to provide field extractors.
* This should not be used directly, but via ClassFieldExtractor (which ensures that it is
* all nicely serializable).
*
* @author Alexander Bagerman
* @author Michael Neale
*/

public class ClassFieldExtractorFactory {

    private static final String           BASE_PACKAGE         = "org/drools/base";

    private static final String           SELF_REFERENCE_FIELD = "this";

    private static final ProtectionDomain PROTECTION_DOMAIN;

    private static final Map              inspectors           = new HashMap();
   
    private static ByteArrayClassLoader   byteArrayClassLoader;

    static {
        PROTECTION_DOMAIN = (ProtectionDomain) AccessController.doPrivileged( new PrivilegedAction() {
            public Object run() {
                return ClassFieldExtractorFactory.class.getProtectionDomain();
            }
        } );
    }

    public static BaseClassFieldExtractor getClassFieldExtractor(final Class clazz,
                                                                 final String fieldName,
                                                                 final ClassLoader classLoader) {
        if ( byteArrayClassLoader == null ) {
            if (classLoader == null ) {
                throw new RuntimeDroolsException("ClassFieldExtractorFactory cannot have a null parent ClassLoader" );
            }           
            byteArrayClassLoader = new ByteArrayClassLoader( classLoader );           
        }
        try {
            // if it is a self reference
            if ( SELF_REFERENCE_FIELD.equals( fieldName ) ) {
                // then just create an instance of the special class field extractor
                return new SelfReferenceClassFieldExtractor( clazz,
                                                             fieldName );
            } else if( fieldName.indexOf( '.' ) > -1 || fieldName.indexOf( '[' ) > -1 ) {
                // we need MVEL extractor for expressions
                return new MVELClassFieldExtractor( clazz,
                                                    fieldName,
                                                    classLoader );
            } else {
                // otherwise, bytecode generate a specific extractor
                ClassFieldInspector inspector = (ClassFieldInspector) inspectors.get( clazz );
                if ( inspector == null ) {
                    inspector = new ClassFieldInspector( clazz );
                    inspectors.put( clazz,
                                    inspector );
                }
                final Class fieldType = (Class) inspector.getFieldTypes().get( fieldName );
                final Method getterMethod = (Method) inspector.getGetterMethods().get( fieldName );
                if( fieldType != null && getterMethod != null ) {
                    final String className = ClassFieldExtractorFactory.BASE_PACKAGE + "/" + Type.getInternalName( clazz ) + "$" + getterMethod.getName();

                    // generating byte array to create target class
                    final byte[] bytes = dump( clazz,
                                               className,
                                               getterMethod,
                                               fieldType,
                                               clazz.isInterface() );
                    // use bytes to get a class

                    final Class newClass = byteArrayClassLoader.defineClass( className.replace( '/',
                                                                                                '.' ),
                                                                             bytes,
                                                                             PROTECTION_DOMAIN );
                    // instantiating target class
                    final Integer index = (Integer) inspector.getFieldNames().get( fieldName );
                    final ValueType valueType = ValueType.determineValueType( fieldType );
                    final Object[] params = {index, fieldType, valueType};
                    return (BaseClassFieldExtractor) newClass.getConstructors()[0].newInstance( params );
                } else {
                    throw new RuntimeDroolsException("Field/method '"+fieldName+"' not found for class '"+clazz.getName()+"'" );
                }
            }
        } catch ( final RuntimeDroolsException e ) {
            throw e;
        } catch ( final Exception e ) {
            throw new RuntimeDroolsException( e );
        }
    }

    private static byte[] dump(final Class originalClass,
                               final String className,
                               final Method getterMethod,
                               final Class fieldType,
                               final boolean isInterface) throws Exception {

        final ClassWriter cw = new ClassWriter( true );

        final Class superClass = getSuperClassFor( fieldType );
        buildClassHeader( superClass,
                          className,
                          cw );

        //        buildConstructor( superClass,
        //                          className,
        //                          cw );

        build3ArgConstructor( superClass,
                              className,
                              cw );

        buildGetMethod( originalClass,
                        className,
                        superClass,
                        getterMethod,
                        cw );

        cw.visitEnd();

        return cw.toByteArray();
    }

    /**
     * Builds the class header
     * 
     * @param clazz The class to build the extractor for
     * @param className The extractor class name
     * @param cw
     */
    protected static void buildClassHeader(final Class superClass,
                                           final String className,
                                           final ClassWriter cw) {
        cw.visit( Opcodes.V1_2,
                  Opcodes.ACC_PUBLIC + Opcodes.ACC_SUPER,
                  className,
                  null,
                  Type.getInternalName( superClass ),
                  null );

        cw.visitSource( null,
                        null );
    }

    //    /**
    //     * Creates a constructor for the field extractor receiving
    //     * the class instance and field name
    //     *
    //     * @param originalClassName
    //     * @param className
    //     * @param cw
    //     */
    //    private static void buildConstructor(final Class superClazz,
    //                                         final String className,
    //                                         final ClassWriter cw) {
    //        MethodVisitor mv;
    //        {
    //            mv = cw.visitMethod( Opcodes.ACC_PUBLIC,
    //                                 "<init>",
    //                                 Type.getMethodDescriptor( Type.VOID_TYPE,
    //                                                           new Type[]{Type.getType( Class.class ), Type.getType( String.class )} ),
    //                                 null,
    //                                 null );
    //            mv.visitCode();
    //            final Label l0 = new Label();
    //            mv.visitLabel( l0 );
    //            mv.visitVarInsn( Opcodes.ALOAD,
    //                             0 );
    //            mv.visitVarInsn( Opcodes.ALOAD,
    //                             1 );
    //            mv.visitVarInsn( Opcodes.ALOAD,
    //                             2 );
    //            mv.visitMethodInsn( Opcodes.INVOKESPECIAL,
    //                                Type.getInternalName( superClazz ),
    //                                "<init>",
    //                                Type.getMethodDescriptor( Type.VOID_TYPE,
    //                                                          new Type[]{Type.getType( Class.class ), Type.getType( String.class )} ) );
    //            final Label l1 = new Label();
    //            mv.visitLabel( l1 );
    //            mv.visitInsn( Opcodes.RETURN );
    //            final Label l2 = new Label();
    //            mv.visitLabel( l2 );
    //            mv.visitLocalVariable( "this",
    //                                   "L" + className + ";",
    //                                   null,
    //                                   l0,
    //                                   l2,
    //                                   0 );
    //            mv.visitLocalVariable( "clazz",
    //                                   Type.getDescriptor( Class.class ),
    //                                   null,
    //                                   l0,
    //                                   l2,
    //                                   1 );
    //            mv.visitLocalVariable( "fieldName",
    //                                   Type.getDescriptor( String.class ),
    //                                   null,
    //                                   l0,
    //                                   l2,
    //                                   2 );
    //            mv.visitMaxs( 0,
    //                          0 );
    //            mv.visitEnd();
    //        }
    //    }

    /**
     * Creates a constructor for the field extractor receiving
     * the index, field type and value type
     *
     * @param originalClassName
     * @param className
     * @param cw
     */
    private static void build3ArgConstructor(final Class superClazz,
                                             final String className,
                                             final ClassWriter cw) {
        MethodVisitor mv;
        {
            mv = cw.visitMethod( Opcodes.ACC_PUBLIC,
                                 "<init>",
                                 Type.getMethodDescriptor( Type.VOID_TYPE,
                                                           new Type[]{Type.getType( int.class ), Type.getType( Class.class ), Type.getType( ValueType.class )} ),
                                 null,
                                 null );
            mv.visitCode();
            final Label l0 = new Label();
            mv.visitLabel( l0 );
            mv.visitVarInsn( Opcodes.ALOAD,
                             0 );
            mv.visitVarInsn( Opcodes.ILOAD,
                             1 );
            mv.visitVarInsn( Opcodes.ALOAD,
                             2 );
            mv.visitVarInsn( Opcodes.ALOAD,
                             3 );
            mv.visitMethodInsn( Opcodes.INVOKESPECIAL,
                                Type.getInternalName( superClazz ),
                                "<init>",
                                Type.getMethodDescriptor( Type.VOID_TYPE,
                                                          new Type[]{Type.getType( int.class ), Type.getType( Class.class ), Type.getType( ValueType.class )} ) );
            final Label l1 = new Label();
            mv.visitLabel( l1 );
            mv.visitInsn( Opcodes.RETURN );
            final Label l2 = new Label();
            mv.visitLabel( l2 );
            mv.visitLocalVariable( "this",
                                   "L" + className + ";",
                                   null,
                                   l0,
                                   l2,
                                   0 );
            mv.visitLocalVariable( "index",
                                   Type.getDescriptor( int.class ),
                                   null,
                                   l0,
                                   l2,
                                   1 );
            mv.visitLocalVariable( "fieldType",
                                   Type.getDescriptor( Class.class ),
                                   null,
                                   l0,
                                   l2,
                                   2 );
            mv.visitLocalVariable( "valueType",
                                   Type.getDescriptor( ValueType.class ),
                                   null,
                                   l0,
                                   l2,
                                   3 );
            mv.visitMaxs( 0,
                          0 );
            mv.visitEnd();
        }
    }

    /**
     * Creates the proxy reader method for the given method
     *
     * @param fieldName
     * @param fieldFlag
     * @param method
     * @param cw
     */
    protected static void buildGetMethod(final Class originalClass,
                                         final String className,
                                         final Class superClass,
                                         final Method getterMethod,
                                         final ClassWriter cw) {

        final Class fieldType = getterMethod.getReturnType();
        Method overridingMethod;
        try {
            overridingMethod = superClass.getMethod( getOverridingMethodName( fieldType ),
                                                     new Class[]{InternalWorkingMemory.class, Object.class} );
        } catch ( final Exception e ) {
            throw new RuntimeDroolsException( "This is a bug. Please report back to JBoss Rules team.",
                                              e );
        }
        final MethodVisitor mv = cw.visitMethod( Opcodes.ACC_PUBLIC,
                                                 overridingMethod.getName(),
                                                 Type.getMethodDescriptor( overridingMethod ),
                                                 null,
                                                 null );

        mv.visitCode();

        final Label l0 = new Label();
        mv.visitLabel( l0 );
        mv.visitVarInsn( Opcodes.ALOAD,
                         2 );
        mv.visitTypeInsn( Opcodes.CHECKCAST,
                          Type.getInternalName( originalClass ) );

        if ( originalClass.isInterface() ) {
            mv.visitMethodInsn( Opcodes.INVOKEINTERFACE,
                                Type.getInternalName( originalClass ),
                                getterMethod.getName(),
                                Type.getMethodDescriptor( getterMethod ) );
        } else {
            mv.visitMethodInsn( Opcodes.INVOKEVIRTUAL,
                                Type.getInternalName( originalClass ),
                                getterMethod.getName(),
                                Type.getMethodDescriptor( getterMethod ) );
        }
        mv.visitInsn( Type.getType( fieldType ).getOpcode( Opcodes.IRETURN ) );
        final Label l1 = new Label();
        mv.visitLabel( l1 );
        mv.visitLocalVariable( "this",
                               "L" + className + ";",
                               null,
                               l0,
                               l1,
                               0 );
        mv.visitLocalVariable( "workingMemory",
                               Type.getDescriptor( InternalWorkingMemory.class ),
                               null,
                               l0,
                               l1,
                               1 );
        mv.visitLocalVariable( "object",
                               Type.getDescriptor( Object.class ),
                               null,
                               l0,
                               l1,
                               2 );
        mv.visitMaxs( 0,
                      0 );
        mv.visitEnd();
    }

    private static String getOverridingMethodName(final Class fieldType) {
        String ret = null;
        if ( fieldType.isPrimitive() ) {
            if ( fieldType == char.class ) {
                ret = "getCharValue";
            } else if ( fieldType == byte.class ) {
                ret = "getByteValue";
            } else if ( fieldType == short.class ) {
                ret = "getShortValue";
            } else if ( fieldType == int.class ) {
                ret = "getIntValue";
            } else if ( fieldType == long.class ) {
                ret = "getLongValue";
            } else if ( fieldType == float.class ) {
                ret = "getFloatValue";
            } else if ( fieldType == double.class ) {
                ret = "getDoubleValue";
            } else if ( fieldType == boolean.class ) {
                ret = "getBooleanValue";
            }
        } else {
            ret = "getValue";
        }
        return ret;
    }

    /**
     * Returns the appropriate Base class field extractor class
     * for the given fieldType
     *
     * @param fieldType
     * @return
     */
    private static Class getSuperClassFor(final Class fieldType) {
        Class ret = null;
        if ( fieldType.isPrimitive() ) {
            if ( fieldType == char.class ) {
                ret = BaseCharClassFieldExtractor.class;
            } else if ( fieldType == byte.class ) {
                ret = BaseByteClassFieldExtractor.class;
            } else if ( fieldType == short.class ) {
                ret = BaseShortClassFieldExtractor.class;
            } else if ( fieldType == int.class ) {
                ret = BaseIntClassFieldExtractor.class;
            } else if ( fieldType == long.class ) {
                ret = BaseLongClassFieldExtractors.class;
            } else if ( fieldType == float.class ) {
                ret = BaseFloatClassFieldExtractor.class;
            } else if ( fieldType == double.class ) {
                ret = BaseDoubleClassFieldExtractor.class;
            } else if ( fieldType == boolean.class ) {
                ret = BaseBooleanClassFieldExtractor.class;
            }
        } else {
            ret = BaseObjectClassFieldExtractor.class;
        }
        return ret;
    }

    /**
     * Simple classloader
     * @author Michael Neale
     */
    static class ByteArrayClassLoader extends ClassLoader {
        public ByteArrayClassLoader(final ClassLoader parent) {
            super( parent );
        }

        public Class defineClass(final String name,
                                 final byte[] bytes,
                                 final ProtectionDomain domain) {
            return defineClass( name,
                                bytes,
                                0,
                                bytes.length,
                                domain );
        }
    }
}
TOP

Related Classes of org.drools.base.ClassFieldExtractorFactory$ByteArrayClassLoader

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.