Package org.trifort.rootbeer.generate.opencl

Source Code of org.trifort.rootbeer.generate.opencl.OpenCLScene

/*
* Copyright 2012 Phil Pratt-Szeliga and other contributors
* http://chirrup.org/
*
* See the file LICENSE for copying permission.
*/

package org.trifort.rootbeer.generate.opencl;

import soot.jimple.NewExpr;
import soot.rbclassload.MethodSignatureUtil;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.trifort.rootbeer.configuration.Configuration;
import org.trifort.rootbeer.configuration.RootbeerPaths;
import org.trifort.rootbeer.entry.ForcedFields;
import org.trifort.rootbeer.entry.CompilerSetup;
import org.trifort.rootbeer.generate.bytecode.MethodCodeSegment;
import org.trifort.rootbeer.generate.bytecode.ReadOnlyTypes;
import org.trifort.rootbeer.generate.opencl.fields.CompositeField;
import org.trifort.rootbeer.generate.opencl.fields.CompositeFieldFactory;
import org.trifort.rootbeer.generate.opencl.fields.FieldCodeGeneration;
import org.trifort.rootbeer.generate.opencl.fields.FieldTypeSwitch;
import org.trifort.rootbeer.generate.opencl.fields.OffsetCalculator;
import org.trifort.rootbeer.generate.opencl.fields.OpenCLField;
import org.trifort.rootbeer.generate.opencl.tweaks.CompileResult;
import org.trifort.rootbeer.generate.opencl.tweaks.CudaTweaks;
import org.trifort.rootbeer.generate.opencl.tweaks.Tweaks;
import org.trifort.rootbeer.util.ReadFile;
import org.trifort.rootbeer.util.ResourceReader;

import soot.*;
import soot.rbclassload.FieldSignatureUtil;
import soot.rbclassload.NumberedType;
import soot.rbclassload.RootbeerClassLoader;

public class OpenCLScene {
  // Scene instance returned by v(); installed with setInstance() and cleared by releaseV().
  private static OpenCLScene m_instance;
  // Counter incremented on every releaseV(); getIdent() returns it as a string.
  private static int m_curentIdent;
  // OpenCLClass wrappers keyed by SootClass.getName(); filled lazily by getOpenCLClass().
  private Map<String, OpenCLClass> m_classes;
  // Array types registered through addArrayType().
  private Set<OpenCLArrayType> m_arrayTypes;
  // Receives every method passed to addMethod(); queried when prototypes and bodies are generated.
  private MethodHierarchies m_methodHierarchies;
  // Set by setUsingGarbageCollector(); reset to false at the start of each non-manual makeSourceCode().
  private boolean m_usesGarbageCollector;
  // Root class of the code segment supplied to addCodeSegment().
  private SootClass m_rootSootClass;
  // Returned by getEndOfStatics(); no assignment to it is visible in this class.
  private int m_endOfStatics;
  // Built in addCodeSegment() from the code segment's root method.
  private ReadOnlyTypes m_readOnlyTypes;
  // Types registered through addInstanceof().
  private Set<OpenCLInstanceof> m_instanceOfs;
  // Produced by buildCompositeFields() at the end of loadTypes().
  private List<CompositeField> m_compositeFields;
  // Every method passed to addMethod(), in call order.
  private List<SootMethod> m_methods;
  // Created in init() and exposed through getClassConstantNumbers().
  private ClassConstantNumbers m_constantNumbers;
  // Generates the field prototypes and bodies for the classes in m_classes.
  private FieldCodeGeneration m_fieldCodeGeneration;
 
  // Start the identifier counter at zero when the class is loaded.
  static {
    m_curentIdent = 0;
  }

  /**
   * Creates an empty scene. No collections are allocated here; call
   * {@link #init()} before using the instance.
   */
  public OpenCLScene(){
  }
 
  /**
   * Allocates the scene's collections and helper objects, then populates
   * them by calling {@link #loadTypes()}.
   */
  public void init(){
    m_classes = new LinkedHashMap<String, OpenCLClass>();
    m_arrayTypes = new LinkedHashSet<OpenCLArrayType>();
    m_methodHierarchies = new MethodHierarchies();
    m_instanceOfs = new HashSet<OpenCLInstanceof>();
    m_methods = new ArrayList<SootMethod>();
    m_constantNumbers = new ClassConstantNumbers();
    m_fieldCodeGeneration = new FieldCodeGeneration();
    // must run last: loadTypes() adds into the collections created above
    loadTypes();
  }

  /**
   * Returns the currently installed scene.
   *
   * @return the instance passed to {@link #setInstance(OpenCLScene)}, or
   *         {@code null} if none is installed (for example after
   *         {@link #releaseV()})
   */
  public static OpenCLScene v(){
    return m_instance;
  }
 
  /**
   * Installs the scene that {@link #v()} will return.
   *
   * @param scene the scene to install
   */
  public static void setInstance(OpenCLScene scene){
    m_instance = scene;
  }

  /**
   * Drops the installed scene and advances the identifier counter, so the
   * value returned by {@link #getIdent()} differs after each release.
   */
  public static void releaseV(){
    m_instance = null;
    m_curentIdent++;
  }
 
  public String getIdent(){
    return "" + m_curentIdent;
  }

  /**
   * Returns a fixed 32-character hexadecimal string; the value is the same
   * for every scene.
   *
   * @return the constant identifier string
   */
  public String getUuid(){
    return "ab850b60f96d11de8a390800200c9a66";
  }

  /**
   * Returns the stored end-of-statics value.
   *
   * NOTE(review): nothing visible in this class assigns the backing field,
   * so this appears to always return 0 -- confirm whether that is intended.
   *
   * @return the current value of the end-of-statics field
   */
  public int getEndOfStatics(){
    return m_endOfStatics;
  }

  /**
   * Looks up the number that the Rootbeer class loader assigned to a class.
   *
   * @param soot_class the class to look up
   * @return the value of {@code RootbeerClassLoader.v().getClassNumber(soot_class)}
   */
  public int getClassType(SootClass soot_class){
    return RootbeerClassLoader.v().getClassNumber(soot_class);
  }
 
  public void addMethod(SootMethod soot_method){
    SootClass soot_class = soot_method.getDeclaringClass();

    OpenCLClass ocl_class = getOpenCLClass(soot_class);
    ocl_class.addMethod(new OpenCLMethod(soot_method, soot_class));

    //add the method
    m_methodHierarchies.addMethod(soot_method);
    m_methods.add(soot_method);
  }
 
  /**
   * Returns the methods registered through {@link #addMethod(SootMethod)}.
   *
   * @return the internal list itself (not a copy)
   */
  public List<SootMethod> getMethods(){
    return m_methods;
  }

  public void addArrayType(OpenCLArrayType array_type){
    if(m_arrayTypes.contains(array_type))
      return;
    m_arrayTypes.add(array_type);
 
 
  public void addInstanceof(Type type){
    OpenCLInstanceof to_add = new OpenCLInstanceof(type);
    if(m_instanceOfs.contains(to_add) == false){
      m_instanceOfs.add(to_add);
    }
  }

  public OpenCLClass getOpenCLClass(SootClass soot_class){   
    String class_name = soot_class.getName();
    if(m_classes.containsKey(class_name)){
      return m_classes.get(class_name);
    } else {
      OpenCLClass ocl_class = new OpenCLClass(soot_class);
      m_classes.put(class_name, ocl_class);
      return ocl_class;
    }
  }

  public void addField(SootField soot_field){
    SootClass soot_class = soot_field.getDeclaringClass();
    OpenCLClass ocl_class = getOpenCLClass(soot_class);
    ocl_class.addField(new OpenCLField(soot_field, soot_class));
  }

  private String getRuntimeBasicBlockClassName(){
    SootClass soot_class = m_rootSootClass;
    OpenCLClass ocl_class = getOpenCLClass(soot_class);
    return ocl_class.getName();
  }

  private String readCudaCodeFromFile(){
    try {
      BufferedReader reader = new BufferedReader(new FileReader("generated.cu"));
      String ret = "";
      while(true){
        String temp = reader.readLine();
        if(temp == null)
          return ret;
        ret += temp+"\n";
      }
    } catch(Exception ex){
      throw new RuntimeException();
    }
  }

  /**
   * Marks the scene as using the garbage collector. When the flag is set,
   * {@code methodBodiesString()} emits {@code #define USING_GARBAGE_COLLECTOR}.
   */
  public void setUsingGarbageCollector(){
    m_usesGarbageCollector = true;
  }

  /**
   * Reports whether the garbage-collector flag is currently set.
   *
   * @return {@code true} once {@link #setUsingGarbageCollector()} has been
   *         called since the flag was last reset
   */
  public boolean getUsingGarbageCollector(){
    return m_usesGarbageCollector;
  }
 
  private void writeTypesToFile(List<NumberedType> types){
    try {
      PrintWriter writer = new PrintWriter(RootbeerPaths.v().getTypeFile());
      for(NumberedType type : types){
        writer.println(type.getNumber()+" "+type.getType().toString());
      }
      writer.flush();
      writer.close();
    } catch(Exception ex){
      ex.printStackTrace();
    }
  }
 
  public int getOutOfMemoryNumber(){
    SootClass soot_class = Scene.v().getSootClass("java.lang.OutOfMemoryError");
    int ret = RootbeerClassLoader.v().getClassNumber(soot_class);
    return ret;
  }
 
  private void loadTypes(){
    Set<String> methods = RootbeerClassLoader.v().getDfsInfo().getMethods()
    MethodSignatureUtil util = new MethodSignatureUtil();
    for(String method_sig : methods){
      util.parse(method_sig);
      SootMethod method = util.getSootMethod();
      addMethod(method);
    }
    CompilerSetup compiler_setup = new CompilerSetup();
    for(String extra_method : compiler_setup.getExtraMethods()){
      util.parse(extra_method);
      addMethod(util.getSootMethod());
    }
   
    Set<SootField> fields = RootbeerClassLoader.v().getDfsInfo().getFields();
    for(SootField field : fields){
      addField(field);
    }

    FieldSignatureUtil field_util = new FieldSignatureUtil();
    ForcedFields forced_fields = new ForcedFields();
    for(String field_sig : forced_fields.get()){
      field_util.parse(field_sig);
      addField(field_util.getSootField());
    }
   
    Set<ArrayType> array_types = RootbeerClassLoader.v().getDfsInfo().getArrayTypes();
    for(ArrayType array_type : array_types){
      OpenCLArrayType ocl_array_type = new OpenCLArrayType(array_type);
      addArrayType(ocl_array_type);
    }
    for(ArrayType array_type : compiler_setup.getExtraArrayTypes()){
      OpenCLArrayType ocl_array_type = new OpenCLArrayType(array_type);
      addArrayType(ocl_array_type);
    }
   
    Set<Type> instanceofs = RootbeerClassLoader.v().getDfsInfo().getInstanceOfs();
    for(Type type : instanceofs){
      addInstanceof(type);
    }
   
    buildCompositeFields()
  }
 
  private String[] makeSourceCode() throws Exception {
    if(Configuration.compilerInstance().isManualCuda()){
      String filename = Configuration.compilerInstance().getManualCudaFilename();
      String cuda_code = readCode(filename);
         
      String[] ret = new String[2];
      ret[0] = cuda_code;
      ret[1] = cuda_code;
      return ret;
    }
   
    m_usesGarbageCollector = false;
   
    List<NumberedType> types = RootbeerClassLoader.v().getDfsInfo().getNumberedTypes();
    writeTypesToFile(types);
       
    StringBuilder unix_code = new StringBuilder();
    StringBuilder windows_code = new StringBuilder();
   
    String method_protos = methodPrototypesString();
    String gc_string = garbageCollectorString();
    String bodies_string = methodBodiesString();
   
    unix_code.append(headerString(true));
    unix_code.append(method_protos);
    unix_code.append(gc_string);
    unix_code.append(bodies_string);
    unix_code.append(kernelString(true));

    windows_code.append(headerString(false));
    windows_code.append(method_protos);
    windows_code.append(gc_string);
    windows_code.append(bodies_string);
    windows_code.append(kernelString(false));
   
    String cuda_unix = setupEntryPoint(unix_code);
    String cuda_windows = setupEntryPoint(windows_code);
   
    //print out code for debugging
    PrintWriter writer = new PrintWriter(new FileWriter(RootbeerPaths.v().getRootbeerHome()+"generated_unix.cu"));
    writer.println(cuda_unix);
    writer.flush();
    writer.close();
   
    //print out code for debugging
    writer = new PrintWriter(new FileWriter(RootbeerPaths.v().getRootbeerHome()+"generated_windows.cu"));
    writer.println(cuda_windows);
    writer.flush();
    writer.close();
   
    NameMangling.v().writeTypesToFile();
       
    String[] ret = new String[2];
    ret[0] = cuda_unix;
    ret[1] = cuda_windows;
    return ret;
  }
 
  private String readCode(String filename){
    ReadFile reader = new ReadFile(filename);
    try {
      return reader.read();
    } catch(Exception ex){
      ex.printStackTrace(System.out);
      throw new RuntimeException(ex);
    }
  }

  private String setupEntryPoint(StringBuilder builder){
    String cuda_code = builder.toString();
    String mangle = NameMangling.v().mangle(VoidType.v());
    String replacement = getRuntimeBasicBlockClassName()+"_gpuMethod"+mangle;
    //class names can have $ in them, make them regex safe
    replacement = replacement.replace("$", "\\$");
    cuda_code = cuda_code.replaceAll("%%invoke_run%%", replacement)
   
    int string_builder_number = RootbeerClassLoader.v().getClassNumber("java.lang.StringBuilder");
    String sbn_str = "" + string_builder_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_StringBuilder_TypeNumber%%", sbn_str);
   
    int null_pointer_number = RootbeerClassLoader.v().getClassNumber("java.lang.NullPointerException");
    String np_str = "" + null_pointer_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_NullPointerException_TypeNumber%%", np_str);

    int out_of_memory_number = RootbeerClassLoader.v().getClassNumber("java.lang.OutOfMemoryError");
    String out_of_memory_str = "" + out_of_memory_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_OutOfMemoryError_TypeNumber%%", out_of_memory_str);
   
    int size = Configuration.compilerInstance().getSharedMemSize();
    String size_str = ""+size;
    cuda_code = cuda_code.replaceAll("%%shared_mem_size%%", size_str);
   
    boolean exceptions = Configuration.compilerInstance().getExceptions();
    String exceptions_str;
    if(exceptions){
      exceptions_str = ""+1;
    } else {
      exceptions_str = ""+0;
    }
    cuda_code = cuda_code.replaceAll("%%using_exceptions%%", exceptions_str);
   
    int string_number = RootbeerClassLoader.v().getClassNumber("java.lang.String");
    String string_str = "" + string_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_String_TypeNumber%%", string_str);

    int integer_number = RootbeerClassLoader.v().getClassNumber("java.lang.Integer");
    String integer_str = "" + integer_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_Integer_TypeNumber%%", integer_str);
   
    int long_number = RootbeerClassLoader.v().getClassNumber("java.lang.Long");
    String long_str = "" + long_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_Long_TypeNumber%%", long_str);
   
    int float_number = RootbeerClassLoader.v().getClassNumber("java.lang.Float");
    String float_str = "" + float_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_Float_TypeNumber%%", float_str);
   
    int double_number = RootbeerClassLoader.v().getClassNumber("java.lang.Double");
    String double_str = "" + double_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_Double_TypeNumber%%", double_str);

    int boolean_number = RootbeerClassLoader.v().getClassNumber("java.lang.Boolean");
    String boolean_str = "" + boolean_number;
    cuda_code = cuda_code.replaceAll("%%java_lang_Boolean_TypeNumber%%", boolean_str);
   
    return cuda_code;
  }
 
  public String[] getOpenCLCode() throws Exception {
    String[] source_code = makeSourceCode();
    return source_code;
  }

  public CompileResult[] getCudaCode() throws Exception {
    String[] source_code = makeSourceCode();
    return new CudaTweaks().compileProgram(source_code[0], Configuration.compilerInstance().getCompileArchitecture());
  }

  private String headerString(boolean unix) throws IOException {
    String defines = "";
    if(Configuration.compilerInstance().getArrayChecks()){
      defines += "#define ARRAY_CHECKS\n";
    }
   
    String specific_path;
    if(unix){
      specific_path = Tweaks.v().getUnixHeaderPath();
    } else {
      specific_path = Tweaks.v().getWindowsHeaderPath();
    }
    if(specific_path == null)
      return "";
    String both_path = Tweaks.v().getBothHeaderPath();
    String both_header = "";
    if(both_path != null){
      both_header = ResourceReader.getResource(both_path);
    }
    String specific_header = ResourceReader.getResource(specific_path);
   
    String barrier_path = Tweaks.v().getBarrierPath();
    String barrier_code = "";
    if(barrier_path != null){
      barrier_code = ResourceReader.getResource(barrier_path);
    }
   
    return defines + "\n" + specific_header + "\n" + both_header + "\n" + barrier_code;
  }
 
  private String kernelString(boolean unix) throws IOException {
    String kernel_path;
    if(unix){
      kernel_path = Tweaks.v().getUnixKernelPath();
    } else {
      kernel_path = Tweaks.v().getWindowsKernelPath();
    }
    String specific_kernel_code = ResourceReader.getResource(kernel_path);
    String both_kernel_code = "";
    String both_kernel_path = Tweaks.v().getBothKernelPath();
    if(both_kernel_path != null){
      both_kernel_code = ResourceReader.getResource(both_kernel_path);
    }
    return both_kernel_code + "\n" + specific_kernel_code;
  }
 
  private String garbageCollectorString() throws IOException {
    String path = Tweaks.v().getGarbageCollectorPath();
    String ret = ResourceReader.getResource(path);
    ret = ret.replace("$$__device__$$", Tweaks.v().getDeviceFunctionQualifier());
    ret = ret.replace("$$__global$$", Tweaks.v().getGlobalAddressSpaceQualifier());
    return ret;
  }
 
  private String methodPrototypesString(){
    //using a set so duplicates get filtered out.
    Set<String> protos = new HashSet<String>();
    StringBuilder ret = new StringBuilder();
   
    ArrayCopyGenerate arr_generate = new ArrayCopyGenerate();
    protos.add(arr_generate.getProto());
   
    List<OpenCLMethod> methods = m_methodHierarchies.getMethods();
    for(OpenCLMethod method : methods){
      protos.add(method.getMethodPrototype());
    }   
    List<OpenCLPolymorphicMethod> poly_methods = m_methodHierarchies.getPolyMorphicMethods();
    for(OpenCLPolymorphicMethod poly_method : poly_methods){
      protos.add(poly_method.getMethodPrototypes());
    }
    protos.add(m_fieldCodeGeneration.prototypes(m_classes));
    for(OpenCLArrayType array_type : m_arrayTypes){
      protos.add(array_type.getPrototypes());
    }
    for(OpenCLInstanceof type : m_instanceOfs){
      protos.add(type.getPrototype());
    }
    Iterator<String> iter = protos.iterator();
    while(iter.hasNext()){
      ret.append(iter.next());
    }
    return ret.toString();
  }

  private String methodBodiesString() throws IOException{
    StringBuilder ret = new StringBuilder();
    if(m_usesGarbageCollector){
      ret.append("#define USING_GARBAGE_COLLECTOR\n");
    }
   
    //a set is used so duplicates get filtered out
    Set<String> bodies = new HashSet<String>();
   
    ArrayCopyTypeReduction reduction = new ArrayCopyTypeReduction();
    Set<OpenCLArrayType> new_types = reduction.run(m_arrayTypes, m_methodHierarchies);
   
    ArrayCopyGenerate arr_generate = new ArrayCopyGenerate();
    bodies.add(arr_generate.get(new_types));
   
    List<OpenCLMethod> methods = m_methodHierarchies.getMethods();
    for(OpenCLMethod method : methods){
      bodies.add(method.getMethodBody());
    }
    List<OpenCLPolymorphicMethod> poly_methods = m_methodHierarchies.getPolyMorphicMethods();
    for(OpenCLPolymorphicMethod poly_method : poly_methods){
      bodies.addAll(poly_method.getMethodBodies());
    }
    FieldTypeSwitch type_switch = new FieldTypeSwitch();
    String field_bodies = m_fieldCodeGeneration.bodies(m_classes, type_switch);
    bodies.add(field_bodies);
    for(OpenCLArrayType array_type : m_arrayTypes){
      bodies.add(array_type.getBodies());
    }
    for(OpenCLInstanceof type : m_instanceOfs){
      bodies.add(type.getBody());
    }
    Iterator<String> iter = bodies.iterator();
    ret.append(type_switch.getFunctions());
    while(iter.hasNext()){
      ret.append(iter.next());
    }
    return ret.toString();
  }
 
  public OffsetCalculator getOffsetCalculator(SootClass soot_class){
    List<CompositeField> composites = getCompositeFields();
    for(CompositeField composite : composites){
      List<SootClass> classes = composite.getClasses();
      if(classes.contains(soot_class))
        return new OffsetCalculator(composite);
    }
    throw new RuntimeException("Cannot find composite field for soot_class");
  }

  /**
   * Records the root class of a code segment, builds the read-only types
   * from its root method, and makes sure an OpenCLClass wrapper exists for
   * the root class.
   *
   * @param codeSegment the code segment whose root class and root method
   *        are used
   */
  public void addCodeSegment(MethodCodeSegment codeSegment){
    m_rootSootClass = codeSegment.getRootSootClass();   
    m_readOnlyTypes = new ReadOnlyTypes(codeSegment.getRootMethod());
    // called for its side effect: registers the root class in m_classes
    getOpenCLClass(m_rootSootClass);
  }

  /**
   * Reports whether an array local is written to. This implementation does
   * not inspect the local and always answers {@code true}.
   *
   * @param local the local to check (unused)
   * @return always {@code true}
   */
  public boolean isArrayLocalWrittenTo(Local local){
    return true;
  }
 
  /**
   * Returns the read-only types built by
   * {@link #addCodeSegment(MethodCodeSegment)}.
   *
   * @return the stored instance, or {@code null} if no code segment has
   *         been added yet
   */
  public ReadOnlyTypes getReadOnlyTypes(){
    return m_readOnlyTypes;
  }

  public boolean isRootClass(SootClass soot_class) {
    return soot_class.getName().equals(m_rootSootClass.getName());
  }

  /**
   * Returns the map from class name to OpenCLClass wrapper.
   *
   * @return the internal map itself (not a copy)
   */
  public Map<String, OpenCLClass> getClassMap(){
    return m_classes;
  }

  /**
   * Returns the composite fields computed by {@code buildCompositeFields()}.
   *
   * @return the stored list, or {@code null} if it has not been built yet
   */
  public List<CompositeField> getCompositeFields() {
    return m_compositeFields;
  }

  private void buildCompositeFields() {
    CompositeFieldFactory factory = new CompositeFieldFactory();
    factory.setup(m_classes);
    m_compositeFields = factory.getCompositeFields();
  }
 
  /**
   * Returns the class constant numbers object created in {@link #init()}.
   *
   * @return the stored instance
   */
  public ClassConstantNumbers getClassConstantNumbers(){
    return m_constantNumbers;
  }
}
TOP

Related Classes of org.trifort.rootbeer.generate.opencl.OpenCLScene

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.