// Package cc.mallet.classify.tui
//
// Source code of cc.mallet.classify.tui.Vectors2FeatureConstraints

package cc.mallet.classify.tui;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.logging.Logger;

import cc.mallet.classify.FeatureConstraintUtil;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;

/**
* Create "feature constraints" from data for use in GE training.
* @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
*/

public class Vectors2FeatureConstraints {

  private static Logger logger = MalletLogger.getLogger(Vectors2FeatureConstraints.class.getName());
 
  public static CommandOption.File vectorsFile = new
    CommandOption.File(Vectors2FeatureConstraints.class, "input", "FILENAME",
    true, null, "Data file used to generate constraints.", null);
 
  public static CommandOption.File constraintsFile = new
    CommandOption.File(Vectors2FeatureConstraints.class, "output", "FILENAME",
    true, null, "Output file for constraints.", null);
 
  public static CommandOption.File featuresFile = new
    CommandOption.File(Vectors2FeatureConstraints.class, "features-file", "FILENAME",
    false, null, "File with list of features used to generate constraints.", null);
 
  public static CommandOption.File ldaFile = new
    CommandOption.File(Vectors2FeatureConstraints.class, "lda-file", "FILENAME",
    false, null, "File with serialized LDA object (if using LDA feature constraint selection).", null);
 
  public static CommandOption.Integer numConstraints = new
    CommandOption.Integer(Vectors2FeatureConstraints.class, "num-constraints", "FILENAME",
    true, 10, "Number of feature constraints.", null);
 
  public static CommandOption.String featureSelection = new
  CommandOption.String(Vectors2FeatureConstraints.class, "feature-selection", "STRING",
  true, "infogain | lda", "Method used to choose feature constraints.", null);
 
  public static CommandOption.String targets = new
  CommandOption.String(Vectors2FeatureConstraints.class, "targets", "STRING",
  true, "none | oracle | heuristic | voted", "Method used to estimate constraint targets.", null);
 
  public static CommandOption.Double majorityProb = new
  CommandOption.Double(Vectors2FeatureConstraints.class, "majority-prob", "DOUBLE",
      false, 0.9, "Probability for majority labels when using heuristic target estimation.", null);

  public static void main(String[] args) {
    CommandOption.process(Vectors2FeatureConstraints.class, args);
    InstanceList list = InstanceList.load(vectorsFile.value)
   
    // Here we will assume that we use all labeled data available. 
    ArrayList<Integer> features = null;
    HashMap<Integer,ArrayList<Integer>> featuresAndLabels = null;

    // if a features file was specified, then load features from the file
    if (featuresFile.wasInvoked()) {
      if (fileContainsLabels(featuresFile.value)) {
        // better error message from dfrankow@gmail.com
        if (targets.value.equals("oracle")) {
          throw new RuntimeException("with --targets oracle, features file must be unlabeled");
        }
        featuresAndLabels = readFeaturesAndLabelsFromFile(featuresFile.value, list.getDataAlphabet(), list.getTargetAlphabet());
      }
      else {
        features = readFeaturesFromFile(featuresFile.value, list.getDataAlphabet());       
      }
    }
   
    // otherwise select features using specified method
    else {
      if (featureSelection.value.equals("infogain")) {
        features = FeatureConstraintUtil.selectFeaturesByInfoGain(list,numConstraints.value);
      }
      else if (featureSelection.value.equals("lda")) {
        try {
          ObjectInputStream ois = new ObjectInputStream(new FileInputStream(ldaFile.value));
          ParallelTopicModel lda = (ParallelTopicModel)ois.readObject();
          features = FeatureConstraintUtil.selectTopLDAFeatures(numConstraints.value, lda, list.getDataAlphabet());
        }
        catch (Exception e) {
          e.printStackTrace();
        }
      }
      else {
        throw new RuntimeException("Unsupported value for feature selection: " + featureSelection.value);
      }
    }
   
    // If the target method is oracle, then we do not need feature "labels".
    HashMap<Integer,double[]> constraints = null;
   
    if (targets.value.equals("none")) {
      constraints = new HashMap<Integer,double[]>();
      for (int fi : features) {    
        constraints.put(fi, null);
      }
    }
    else if (targets.value.equals("oracle")) {
      constraints = FeatureConstraintUtil.setTargetsUsingData(list, features);
    }
    else {
      // For other methods, we need to get feature labels, as
      // long as they haven't been already loaded from disk.
      if (featuresAndLabels == null) {
        featuresAndLabels = FeatureConstraintUtil.labelFeatures(list,features);
       
        for (int fi : featuresAndLabels.keySet()) {
          logger.info(list.getDataAlphabet().lookupObject(fi) + ":  ");
          for (int li : featuresAndLabels.get(fi)) {
            logger.info(list.getTargetAlphabet().lookupObject(li) + " ");
          }
        }
       
      }
      if (targets.value.equals("heuristic")) {
        constraints = FeatureConstraintUtil.setTargetsUsingHeuristic(featuresAndLabels,list.getTargetAlphabet().size(),majorityProb.value);
      }
      else if (targets.value.equals("voted")) {
        constraints = FeatureConstraintUtil.setTargetsUsingFeatureVoting(featuresAndLabels,list);
      }
      else {
        throw new RuntimeException("Unsupported value for targets: " + targets.value);
      }
    }
    writeConstraints(constraints,constraintsFile.value,list.getDataAlphabet(),list.getTargetAlphabet())
  }
 
  private static boolean fileContainsLabels(File file) {
    String line = "";
    try {
      BufferedReader reader = new BufferedReader(new FileReader(file));
      line = reader.readLine().trim();
    }
    catch (Exception e) { 
      e.printStackTrace();
      System.exit(1);
    }
     
    String[] split = line.split("\\s+");
   
    if (split.length == 1) {
      return false;
    }
    return true;
  }
 
  private static ArrayList<Integer> readFeaturesFromFile(File file, Alphabet dataAlphabet) {
    ArrayList<Integer> features = new ArrayList<Integer>();
   
    try {
      BufferedReader reader = new BufferedReader(new FileReader(file));
     
      String line = reader.readLine();
      while (line != null) {
        line = line.trim();
        int featureIndex = dataAlphabet.lookupIndex(line,false);
        features.add(featureIndex);
        line = reader.readLine();
      }
    }
    catch (Exception e) { 
      e.printStackTrace();
      System.exit(1);
    }
    return features;
  }
 
  public static HashMap<Integer,ArrayList<Integer>> readFeaturesAndLabelsFromFile(File file, Alphabet dataAlphabet, Alphabet targetAlphabet) {
    HashMap<Integer,ArrayList<Integer>> featuresAndLabels = new HashMap<Integer,ArrayList<Integer>>();
   
    try {
      BufferedReader reader = new BufferedReader(new FileReader(file));
     
      String line = reader.readLine();
      while (line != null) {
        line = line.trim();
        String[] split = line.split("\\s+");
        int featureIndex = dataAlphabet.lookupIndex(split[0],false);
        // better error message from dfrankow@gmail.com
        if (featureIndex == -1) {
          throw new RuntimeException("Couldn't find feature '"
            + split[0] + "' in the data alphabet.");
        }
       
        ArrayList<Integer> labels = new ArrayList<Integer>();
        for (int i = 1; i < split.length; i++) {
          // TODO should these be label names?
          int li = targetAlphabet.lookupIndex(split[i]);
          labels.add(li);
          logger.info("found label " + li);
        }
        featuresAndLabels.put(featureIndex,labels);
        line = reader.readLine();
      }
    }
    catch (Exception e) { 
      e.printStackTrace();
      System.exit(1);
    }
    return featuresAndLabels;
  }
 
  private static void writeConstraints(HashMap<Integer,double[]> constraints, File constraintsFile, Alphabet dataAlphabet, Alphabet targetAlphabet) {
   
    if (constraints.size() == 0) {
      logger.warning("No constraints written!");
      return;
    }
   
    try {
      FileWriter writer = new FileWriter(constraintsFile);
      for (int fi : constraints.keySet()) {
        writer.write(dataAlphabet.lookupObject(fi) + " ");
        double[] p = constraints.get(fi);
        if (p != null) {
          for (int li = 0; li < p.length; li++) {
            writer.write(targetAlphabet.lookupObject(li) + ":" + p[li] + " ");
          }
        }
        writer.write("\n");
      }
      writer.close();
    }
    catch (Exception e) {
      e.printStackTrace();
      System.exit(1);
    }
  }
}
/*
 * TOP
 *
 * Related Classes of cc.mallet.classify.tui.Vectors2FeatureConstraints
 *
 * TOP
 * Copyright © 2018 www.massapi.com. All rights reserved.
 * All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.
 */