Package upenn.junto.app

Source Code of upenn.junto.app.ConfigTuner

package upenn.junto.app;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Iterator;

import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.hash.TObjectDoubleHashMap;

import upenn.junto.config.ConfigReader;
import upenn.junto.util.CollectionUtil;
import upenn.junto.util.Constants;
import upenn.junto.util.Defaults;
import upenn.junto.util.MessagePrinter;

public class ConfigTuner {

  private static ArrayList<Hashtable>
    GetAllCombinations(Hashtable tuningConfig) {
    ArrayList<Hashtable> configs = new ArrayList<Hashtable>();

    Iterator iter = tuningConfig.keySet().iterator();
    while (iter.hasNext()) {
      String paramKey = (String) iter.next();
      String paramVal = (String) tuningConfig.get(paramKey);
     
      // e.g. mu1 = 1e-8,1,1e-8
      String[] fields = paramVal.split(",");
      int currSize = configs.size();
      for (int fi = 0; fi < fields.length; ++fi) {
        // add the first configuration, if none exists
        if (configs.size() == 0) {
          configs.add(new Hashtable());
          ++currSize;
        }
        for (int ci = 0; ci < currSize; ++ci) {
          // the first value can be added to existing
          // configurations.
          if (fi == 0) {
            configs.get(ci).put(paramKey, fields[fi]);
          } else {
            Hashtable nc = (Hashtable) configs.get(ci).clone();
            nc.put(paramKey, fields[fi]);
           
            // append the new config to the end of the list
            configs.add(nc);
          }
        }
      }
    }

    System.out.println("Total config (non-unique) combinations: " + configs.size());
    return (configs);
  }

  private static void Run(Hashtable tuningConfig) {
    // some essential options terminate if they are note specified
    String idenStr = Defaults.GetValueOrDie(tuningConfig, "iden_str");
    String logDir = Defaults.GetValueOrDie(tuningConfig, "log_output_dir");
    String opDir = Defaults.GetValueOrDefault(
                                           (String) tuningConfig.get("output_dir"), null);
    boolean skipExistingConfigs =
      Defaults.GetValueOrDefault((String) tuningConfig.get("skip_existing_config"), false);

    // config file with post-tuning testing details (i.e. final test file etc.)
    String finalTestConfigFile = (String) tuningConfig.get("final_config_file");
    tuningConfig.remove("final_config_file");

    // generate all possible combinations (non unique)
    ArrayList<Hashtable> configs = GetAllCombinations(tuningConfig);
   
    ArrayList<ArrayList> results = new ArrayList<ArrayList>();
    HashSet<String> uniqueConfigs = new HashSet<String>();
   
    // map from algo to the current best scores and the corresponding config
    HashMap<String,Hashtable> algo2BestConfig = new HashMap<String,Hashtable>();
    TObjectDoubleHashMap algo2BestScore = new TObjectDoubleHashMap();
   
    // store console
    PrintStream consoleOut = System.out;
    PrintStream consoleErr = System.err;

    for (int ci = 0; ci < configs.size(); ++ci) {
      Hashtable c = configs.get(ci);
     
      // if this a post-tune config, then generate seed and test files
      if (Defaults.GetValueOrDefault((String) c.get("is_final_run"), false)) {
        String splitId = Defaults.GetValueOrDie(c, "split_id");
        c.put("seed_file", c.remove("seed_base") + "." + splitId + ".train");
        c.put("test_file", c.remove("test_base") + "." + splitId + ".test");
      }
     
      // output file name is considered a unique identifier of a configuration
      String outputFile = GetOutputFileName(c, opDir, idenStr);
      if (uniqueConfigs.contains(outputFile)) {
        continue;
      }
      uniqueConfigs.add(outputFile);
      if (opDir != null) {
        c.put("output_file", outputFile);
      }

      System.out.println("Working with config: " + c.toString());

      try {
        // reset System.out so that the log printed using System.out.println
        // is directed to the right log file
        String logFile = GetLogFileName(c, logDir, idenStr);
       
        // if the log file exists, then don't repeat
        File lf = new File(logFile);
        if (skipExistingConfigs && lf.exists()) {
          continue;
        }
       
        FileOutputStream fos = new FileOutputStream(new File(logFile));
        PrintStream ps = new PrintStream(fos);
        System.setOut(ps);
        System.setErr(ps);
     
        results.add(new ArrayList());
        JuntoConfigRunner.apply(c, results.get(results.size() - 1));
        UpdateBestConfig((String) c.get("algo"), algo2BestScore,
                         algo2BestConfig, c, results.get(results.size() - 1));

        // reset System.out back to the original console value
        System.setOut(consoleOut);
        System.setErr(consoleErr);
       
        // close log file
        fos.close();

      } catch (FileNotFoundException fnfe) {
        fnfe.printStackTrace();
      } catch (IOException ioe) {
        ioe.printStackTrace();
      }
    }
   
    // print out the best parameters for each algorithm
    Iterator algoIter = algo2BestConfig.keySet().iterator();
    while (algoIter.hasNext()) {
      String algo = (String) algoIter.next();
      System.out.println("\n#################\n" +
                         "BEST_CONFIG_FOR " + algo + " " +
                         algo2BestScore.get(algo) + "\n" +
                         CollectionUtil.Map2StringPrettyPrint(algo2BestConfig.get(algo)));
     
      // run test with tuned parameters, if requested
      if (finalTestConfigFile != null) {
        Hashtable finalTestConfig = (Hashtable) algo2BestConfig.get(algo).clone();
       
        // add additional config options from the file to the tuned params
        finalTestConfig = ConfigReader.read_config(finalTestConfig, finalTestConfigFile);
        JuntoConfigRunner.apply(finalTestConfig, null);
      }
    }
  }
 
  private static String GetOutputFileName(Hashtable c, String opDir, String idenStr) {
    String outputFile = " ";
    if (c.get("algo").equals("mad") ||
        c.get("algo").equals("lgc") ||
        c.get("algo").equals("am") ||
        c.get("algo").equals("lclp")) {
      outputFile = opDir + "/" + GetBaseName2(c, idenStr);

    } else if (c.get("algo").equals("maddl")) {
      outputFile = opDir + "/" +
        GetBaseName2(c, idenStr) +
        ".mu4_" + c.get("mu4");

    } else if (c.get("algo").equals("adsorption") || c.get("algo").equals("lp_zgl")) {
      outputFile = opDir + "/" + GetBaseName(c, idenStr);
    } else {
      MessagePrinter.PrintAndDie("output_1 file can't be empty!");
    }
    return (outputFile);
  }
 
  private static String GetLogFileName(Hashtable c, String logDir, String idenStr) {
    String logFile = "";
    if (c.get("algo").equals("mad") ||
        c.get("algo").equals("lgc") ||
        c.get("algo").equals("am") ||
        c.get("algo").equals("lclp")) {
      logFile = logDir + "/" + "log." + GetBaseName2(c, idenStr);

    } else if (c.get("algo").equals("maddl")) {
      logFile = logDir + "/" +
        "log." +
        GetBaseName2(c, idenStr) +
        ".mu4_" + c.get("mu4");

    } else if (c.get("algo").equals("adsorption") || c.get("algo").equals("lp_zgl")) {
      logFile = logDir + "/" +
        "log." + GetBaseName(c, idenStr);
    } else {
      MessagePrinter.PrintAndDie("output_2 file can't be empty!");
    }
    return (logFile);
  }
 
  private static String GetBaseName(Hashtable c, String idenStr) {
    String base = idenStr;
           
    if (c.containsKey("max_seeds_per_class")) {
      base += ".spc_" + c.get("max_seeds_per_class");
    }
   
    base += "." + c.get("algo");
    if (c.containsKey("use_bipartite_optimization")) {
      base += ".bipart_opt_" + c.get("use_bipartite_optimization");
    }
    if (c.containsKey("top_k_neighbors")) {
      base += ".K_" + c.get("top_k_neighbors");
    }
    if (c.containsKey("prune_threshold")) {
      base += ".P_" + c.get("prune_threshold");
    }
    if (c.containsKey("high_prune_thresh")) {
      base += ".feat_prune_high_" + c.get("high_prune_thresh");
    }
    if (c.containsKey("keep_top_k_labels")) {
      base += ".top_labels_" + c.get("keep_top_k_labels");
    }
    if (c.containsKey("train_fract")) {
      base += ".train_fract_" + c.get("train_fract");
    }
    if (Defaults.GetValueOrDefault((String) c.get("set_gaussian_kernel_weights"), false)) {
      double sigmaFactor = Double.parseDouble(Defaults.GetValueOrDie(c, "gauss_sigma_factor"));
      base += ".gk_sig_" + sigmaFactor;
    }
   
    if (c.containsKey("algo") && (c.get("algo").equals("adsorption") ||
                                  c.get("algo").equals("mad") ||
                                  c.get("algo").equals("maddl"))) {
      double beta = Defaults.GetValueOrDefault((String) c.get("beta"), 2.0);
      base += ".beta_" + beta;
    }
   
    // if this a post-tune config, then generate seed and test files
    if (Defaults.GetValueOrDefault((String) c.get("is_final_run"), false)) {
      base += ".split_id_" + Defaults.GetValueOrDie(c, "split_id");
    }
   
    return (base);
  }
 
  private static String GetBaseName2(Hashtable c, String idenStr) {
    String base = GetBaseName(c, idenStr) +
      ".mu1_" + c.get("mu1") +
      ".mu2_" + c.get("mu2") +
      ".mu3_" + c.get("mu3") +
      ".norm_" + c.get("norm");
    return (base);
  }
 
  private static void UpdateBestConfig(String algo, TObjectDoubleHashMap algo2BestScore,
                                       HashMap<String,Hashtable> algo2BestConfig, Hashtable config,
                                       ArrayList perIterMultiScores) {
    TDoubleArrayList perIterScores = new TDoubleArrayList();
    for (int i = 1; i < perIterMultiScores.size(); ++i) {
      TObjectDoubleHashMap r = (TObjectDoubleHashMap) perIterMultiScores.get(i);
      perIterScores.add(r.get(Constants.GetMRRString()));
    }

    if (perIterScores.size() > 0) {
      //      System.out.println("SIZE: " + perIterScores.size());
      int mi = 0;
      for (int i = 1; i < perIterScores.size(); ++i) {
        if (perIterScores.get(i) > perIterScores.get(mi)) {
          mi = i;
        }
      }
      //      System.out.println("max_idx: " + mi + " " + perIterScores.toString());
      double maxScore = perIterScores.get(mi); // perIterScores.max();
      if (algo2BestScore.size() == 0 || algo2BestScore.get(algo) < maxScore) {
        //        System.out.println("new best score: " + maxScore);
        // best iteration
        int bestIter = perIterScores.indexOf(maxScore) + 1;

        algo2BestScore.put(algo, maxScore);
        algo2BestConfig.put(algo, (Hashtable) config.clone());
        algo2BestConfig.get(algo).put("iters", bestIter);
      }
    }
  }

  public static void main(String[] args) {
    Hashtable tuningConfig = ConfigReader.read_config(args);
    Run(tuningConfig);
  }
}
TOP

Related Classes of upenn.junto.app.ConfigTuner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.