Package org.trifort.rootbeer.test

Source Code of org.trifort.rootbeer.test.RootbeerTestAgent

/*
* Copyright 2012 Phil Pratt-Szeliga and other contributors
* http://chirrup.org/
*
* See the file LICENSE for copying permission.
*/

package org.trifort.rootbeer.test;

import java.io.ByteArrayOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;

import org.trifort.rootbeer.configuration.Configuration;
import org.trifort.rootbeer.runtime.Context;
import org.trifort.rootbeer.runtime.Kernel;
import org.trifort.rootbeer.runtime.Rootbeer;
import org.trifort.rootbeer.runtime.RootbeerGpu;
import org.trifort.rootbeer.runtime.ThreadConfig;
import org.trifort.rootbeer.runtime.util.Stopwatch;
import org.trifort.rootbeer.util.ForceGC;

public class RootbeerTestAgent {

  private long m_cpuTime;
  private long m_gpuTime;
  private boolean m_passed;
  private String m_message;
  private List<String> m_failedTests;
 
  public RootbeerTestAgent(){
    m_failedTests = new ArrayList<String>();
  }
 
  public void testOne(ClassLoader cls_loader, String test_case) throws Exception {
    Class test_case_cls = cls_loader.loadClass(test_case);
    Object test_case_obj = test_case_cls.newInstance();
    if(test_case_obj instanceof TestSerialization){
      TestSerialization test_ser = (TestSerialization) test_case_obj;
      System.out.println("[TEST 1/1] "+test_ser.toString());
      if(test_case.equals("org.trifort.rootbeer.testcases.rootbeertest.gpurequired.ChangeThreadTest")){
        testChangeThread(test_ser, true);
      } else {
        test(test_ser, true);
      }
      if(m_passed){
        System.out.println("  PASSED");
        System.out.println("  Cpu time: "+m_cpuTime+" ms");
        System.out.println("  Gpu time: "+m_gpuTime+" ms");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+m_message);
      }  
    } else if(test_case_obj instanceof TestException){
      TestException test_ex = (TestException) test_case_obj;
      System.out.println("[TEST 1/1] "+test_ex.toString());
      ex_test(test_ex, true);
      if(m_passed){
        System.out.println("  PASSED");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+m_message);
      }       
    } else if(test_case_obj instanceof TestKernelTemplate){
      TestKernelTemplate test_kernel_template = (TestKernelTemplate) test_case_obj;
      System.out.println("[TEST 1/1] "+test_kernel_template.toString());
      test(test_kernel_template, true);
      if(m_passed){
        System.out.println("  PASSED");
        System.out.println("  Cpu time: "+m_cpuTime+" ms");
        System.out.println("  Gpu time: "+m_gpuTime+" ms");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+m_message);
      }  
    } else if(test_case_obj instanceof TestApplication){
      TestApplication test_application = (TestApplication) test_case_obj;
      System.out.println("[TEST 1/1] "+test_application.toString());
      if(test_application.test()){
        System.out.println("  PASSED");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+test_application.errorMessage());
      }
    } else {
      throw new RuntimeException("unknown test case type");
    }
  }
 
  public void test(ClassLoader cls_loader, boolean run_hard_tests) throws Exception {
    LoadTestSerialization loader = new LoadTestSerialization();
    List<TestSerialization> creators = loader.load(cls_loader, "org.trifort.rootbeer.test.Main", run_hard_tests);
    List<TestException> ex_creators = loader.loadException(cls_loader, "org.trifort.rootbeer.test.ExMain");
    List<TestSerialization> change_thread = loader.load(cls_loader, "org.trifort.rootbeer.test.ChangeThread", run_hard_tests);
    List<TestKernelTemplate> kernel_template_creators = loader.loadKernelTemplate(cls_loader, "org.trifort.rootbeer.test.KernelTemplateMain");
    List<TestApplication> application_creators = loader.loadApplication(cls_loader, "org.trifort.rootbeer.test.ApplicationMain");
    int num_tests = creators.size() + ex_creators.size() + change_thread.size() +
      kernel_template_creators.size() + application_creators.size();
    int test_num = 1;

    for(TestSerialization creator : creators){
      System.out.println("[TEST "+test_num+"/"+num_tests+"] "+creator.toString());
      test(creator, false);
      ForceGC.gc();
      if(m_passed){
        System.out.println("  PASSED");
        System.out.println("  Cpu time: "+m_cpuTime+" ms");
        System.out.println("  Gpu time: "+m_gpuTime+" ms");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+m_message);
        m_failedTests.add(creator.toString());
      }       
      ++test_num;
    }

    for(TestException ex_creator : ex_creators){
      System.out.println("[TEST "+test_num+"/"+num_tests+"] "+ex_creator.toString());
      ex_test(ex_creator, false);
      if(m_passed){
        System.out.println("  PASSED");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+m_message);
        m_failedTests.add(ex_creator.toString());
      }       
      ++test_num;
    }
   
    for(TestSerialization creator : change_thread){
      System.out.println("[TEST "+test_num+"/"+num_tests+"] "+creator.toString());
      testChangeThread(creator, false);
      if(m_passed){
        System.out.println("  PASSED");
        System.out.println("  Cpu time: "+m_cpuTime+" ms");
        System.out.println("  Gpu time: "+m_gpuTime+" ms");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+m_message);
        m_failedTests.add(creator.toString());
      }       
      ++test_num;
    }
   
    for(TestKernelTemplate kernel_template : kernel_template_creators){
      System.out.println("[TEST "+test_num+"/"+num_tests+"] "+kernel_template.toString());
      test(kernel_template, false);
      ForceGC.gc();
      if(m_passed){
        System.out.println("  PASSED");
        System.out.println("  Cpu time: "+m_cpuTime+" ms");
        System.out.println("  Gpu time: "+m_gpuTime+" ms");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+m_message);
        m_failedTests.add(kernel_template.toString());
      }       
      ++test_num;
    }
   
    for(TestApplication application : application_creators){
      System.out.println("[TEST "+test_num+"/"+num_tests+"] "+application.toString());
      if(application.test()){
        System.out.println("  PASSED");
      } else {
        System.out.println("  FAILED");
        System.out.println("  "+application.errorMessage());
        m_failedTests.add(application.toString());
      }
      ++test_num;
    }

   
    int test_passed = num_tests - m_failedTests.size();
    System.out.println(test_passed+"/"+num_tests+" tests PASS");
    if(test_passed == num_tests){
      System.out.println("ALL TESTS PASS!");
    } else {
      System.out.println("Failing tests:");
      for(String failure : m_failedTests){
        System.out.println("  "+failure);
      }
    }
  }
 
  private void test(TestSerialization creator, boolean print_mem) {
    int i = 0;
    try {     
      Rootbeer rootbeer = new Rootbeer();
      Configuration.setPrintMem(print_mem);
      List<Kernel> known_good_items = creator.create();
      List<Kernel> testing_items = creator.create();
      Stopwatch watch = new Stopwatch();
      watch.start();
      rootbeer.run(testing_items);
      m_passed = true;
      watch.stop();
      m_gpuTime = watch.elapsedTimeMillis();
      watch.start();
      for(i = 0; i < known_good_items.size(); ++i){      
        Kernel known_good_item = known_good_items.get(i);
        known_good_item.gpuMethod();
      }
      watch.stop();
      m_cpuTime = watch.elapsedTimeMillis();
      for(i = 0; i < known_good_items.size(); ++i){
        Kernel known_good_item = known_good_items.get(i);
        Kernel testing_item = testing_items.get(i);
        if(!creator.compare(known_good_item, testing_item)){
          m_message = "Compare failed at: "+i;
          m_passed = false;
          return;
        }
      }
    } catch(Throwable ex){
      ex.printStackTrace(System.out);
      m_message = "Exception thrown at index: "+i;
      m_passed = false;
    }
  }

  private void test(TestKernelTemplate creator, boolean print_mem) {
    int i = 0;
    try {     
      Rootbeer rootbeer = new Rootbeer();
      Configuration.setPrintMem(print_mem);
      Kernel known_good_item = creator.create();
      Kernel testing_item = creator.create();
      ThreadConfig thread_config = creator.getThreadConfig();

      Stopwatch watch = new Stopwatch();
      watch.start();
      Context context = rootbeer.createDefaultContext();
      context.setKernel(testing_item);
      context.setThreadConfig(thread_config);
      context.buildState();
      context.run();
      context.close();
      m_passed = true;
      watch.stop();
      m_gpuTime = watch.elapsedTimeMillis();
      watch.start();
      RootbeerGpu.setBlockDimx(thread_config.getThreadCountX());
      RootbeerGpu.setBlockDimy(thread_config.getThreadCountY());
      RootbeerGpu.setBlockDimz(thread_config.getThreadCountZ());
      RootbeerGpu.setGridDimx(thread_config.getBlockCountX());
      RootbeerGpu.setGridDimy(thread_config.getBlockCountY());
     
      for(int blockx = 0; blockx < thread_config.getBlockCountX(); ++blockx){
        RootbeerGpu.setBlockIdxx(blockx);
        for(int blocky = 0; blocky < thread_config.getBlockCountY(); ++blocky){
          RootbeerGpu.setBlockIdxy(blocky);
          for(int threadx = 0; threadx < thread_config.getThreadCountX(); ++threadx){
            RootbeerGpu.setThreadIdxx(threadx);
            for(int thready = 0; thready < thread_config.getThreadCountY(); ++thready){
              RootbeerGpu.setThreadIdxy(thready);
              for(int threadz = 0; threadz < thread_config.getThreadCountZ(); ++threadz){
                RootbeerGpu.setThreadIdxz(threadz);
                known_good_item.gpuMethod();
              }
            }
          }
        }
      }
      watch.stop();
      m_cpuTime = watch.elapsedTimeMillis();
      if(!creator.compare(known_good_item, testing_item)){
        m_message = "Compare failed at: "+i;
        m_passed = false;
        return;
      }
    } catch(Throwable ex){
      ex.printStackTrace(System.out);
      m_message = "Exception thrown at index: "+i;
      m_passed = false;
    }
  }
 
  private void ex_test(TestException creator, boolean print_mem) {
    Rootbeer rootbeer = new Rootbeer();
    Configuration.setPrintMem(print_mem);
    List<Kernel> testing_items = creator.create();
    try {
      rootbeer.run(testing_items);
      m_passed = false;
      m_message = "No exception thrown when expecting one.";
    } catch(Throwable ex){
      m_passed = creator.catchException(ex);
      if(m_passed == false){
        m_message = "Exception is: "+ex.toString();
      }
    }
  }

  private void testChangeThread(TestSerialization creator, boolean print_mem) {
    Thread t = new Thread(new ChangeThread(creator, print_mem));
    t.start();
    try {
      t.join();
    } catch(Exception ex){
      ex.printStackTrace();
    }
  }
 
  private class ChangeThread implements Runnable {

    private TestSerialization m_creator;
    private boolean m_printMem;
    private Rootbeer m_rootbeer;
   
    public ChangeThread(TestSerialization creator, boolean print_mem){
      m_creator = creator;
      m_printMem = print_mem;
      m_rootbeer = new Rootbeer();
    }
   
    public void run() {
      int i = 0;
      try {     
        Configuration.setPrintMem(m_printMem);
        List<Kernel> known_good_items = m_creator.create();
        List<Kernel> testing_items = m_creator.create();
        Stopwatch watch = new Stopwatch();
        watch.start();
        m_rootbeer.run(testing_items);
        m_passed = true;
        watch.stop();
        m_gpuTime = watch.elapsedTimeMillis();
        watch.start();
        for(i = 0; i < known_good_items.size(); ++i){      
          Kernel known_good_item = known_good_items.get(i);
          known_good_item.gpuMethod();
        }
        watch.stop();
        m_cpuTime = watch.elapsedTimeMillis();
        for(i = 0; i < known_good_items.size(); ++i){
          Kernel known_good_item = known_good_items.get(i);
          Kernel testing_item = testing_items.get(i);
          if(!m_creator.compare(known_good_item, testing_item)){
            m_message = "Compare failed at: "+i;
            m_passed = false;
            return;
          }
        }
      } catch(Throwable ex){
        ex.printStackTrace(System.out);
        m_message = "Exception thrown at index: "+i;
        m_passed = false;
      }
    }
  }
}
TOP

Related Classes of org.trifort.rootbeer.test.RootbeerTestAgent

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.